diff --git a/.gitignore b/.gitignore index 0b2856cc..1b2cf071 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ upload *.exe *.db build -*.db-journal \ No newline at end of file +*.db-journal +logs \ No newline at end of file diff --git a/README.md b/README.md index 45c8b603..02412ec8 100644 --- a/README.md +++ b/README.md @@ -71,10 +71,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [360 智脑](https://ai.360.cn) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) + + [x] [CloseAI](https://console.closeai-asia.com/r/2412) + [x] [API2D](https://api2d.com/r/197971) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) - + [x] [CloseAI](https://console.closeai-asia.com/r/2412) + [x] 自定义渠道:例如各种未收录的第三方代理服务 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -109,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 + 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 @@ -209,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 +#### QChatGPT - QQ机器人 +项目主页:https://github.com/RockChinQ/QChatGPT + +根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 + +可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 + ### 部署到第三方平台
部署到 Sealos @@ -275,8 +284,9 @@ graph LR 不加的话将会使用负载均衡的方式使用多个渠道。 ### 环境变量 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + 例子:`SESSION_SECRET=random_string` 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 @@ -303,11 +313,19 @@ graph LR + 例子:`CHANNEL_TEST_FREQUENCY=1440` 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` +10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` + + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 +11. `BATCH_UPDATE_INTERVAL`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` +12. 请求频率限制: + + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 + 例子:`--port 3000` -2. `--log-dir `: 指定日志文件夹,如果没有设置,日志将不会被保存。 +2. `--log-dir `: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 + 例子:`--log-dir ./logs` 3. `--version`: 打印系统版本号并退出。 4. `--help`: 查看命令的使用帮助和参数说明。 @@ -339,6 +357,7 @@ https://openai.justsong.cn 5. ChatGPT Next Web 报错:`Failed to fetch` + 部署的时候不要设置 `BASE_URL`。 + 检查你的接口地址和 API Key 有没有填对。 + + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 
报错:`当前分组负载已饱和,请稍后再试` + 上游通道 429 了。 @@ -352,4 +371,4 @@ https://openai.justsong.cn 同样适用于基于本项目的二开项目。 -依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 \ No newline at end of file +依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 diff --git a/common/constants.go b/common/constants.go index e5211e3d..794a795f 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,13 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var BatchUpdateEnabled = false +var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + const ( RoleGuestUser = 0 RoleCommonUser = 1 @@ -111,10 +118,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = 180 + GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = 60 + GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -154,49 +161,53 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + 
ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 } diff --git a/common/init.go b/common/init.go index 0f22c69b..1e9c85ce 100644 --- a/common/init.go +++ b/common/init.go @@ -12,7 +12,7 @@ var ( Port = flag.Int("port", 3000, "the listening port") PrintVersion = flag.Bool("version", false, "print version and exit") PrintHelp = flag.Bool("help", false, "print help and exit") - LogDir = 
flag.String("log-dir", "", "specify the log directory") + LogDir = flag.String("log-dir", "./logs", "specify the log directory") ) func printHelp() { diff --git a/common/logger.go b/common/logger.go index 3658dbdb..61627217 100644 --- a/common/logger.go +++ b/common/logger.go @@ -1,29 +1,47 @@ package common import ( + "context" "fmt" "github.com/gin-gonic/gin" "io" "log" "os" "path/filepath" + "sync" "time" ) -func SetupGinLog() { +const ( + loggerINFO = "INFO" + loggerWarn = "WARN" + loggerError = "ERR" +) + +const maxLogCount = 1000000 + +var logCount int +var setupLogLock sync.Mutex +var setupLogWorking bool + +func SetupLogger() { if *LogDir != "" { - commonLogPath := filepath.Join(*LogDir, "common.log") - errorLogPath := filepath.Join(*LogDir, "error.log") - commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + ok := setupLogLock.TryLock() + if !ok { + log.Println("setup log is already working") + return + } + defer func() { + setupLogLock.Unlock() + setupLogWorking = false + }() + logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") } - errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) } } @@ -37,6 +55,36 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } +func LogInfo(ctx context.Context, msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func LogWarn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx 
context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(RequestIdKey) + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + logCount++ // we don't need accurate count, so no lock here + if logCount > maxLogCount && !setupLogWorking { + logCount = 0 + setupLogWorking = true + go func() { + SetupLogger() + }() + } +} + func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) diff --git a/common/model-ratio.go b/common/model-ratio.go index 70758805..eeb23e07 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag - "qwen-plus-v1": 0.5715, // Same as above - "SparkDesk": 0.8572, // TBD + "qwen-v1": 0.8572, // ¥0.012 / 1k tokens + "qwen-plus-v1": 1, // ¥0.014 / 1k tokens + "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/common/utils.go b/common/utils.go index bb9b7e0c..ab901b77 100644 --- a/common/utils.go +++ b/common/utils.go @@ -171,6 +171,11 @@ func GetTimestamp() int64 { return time.Now().Unix() } +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + func Max(a int, b int) int { if a >= b { return a @@ -190,3 +195,7 @@ 
func GetOrDefault(env string, defaultValue int) int { } return num } + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} diff --git a/controller/billing.go b/controller/billing.go index 79eae1e2..42e86aea 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) { if err != nil { openAIError := OpenAIError{ Message: err.Error(), - Type: "one_api_error", + Type: "upstream_error", } c.JSON(200, gin.H{ "error": openAIError, diff --git a/controller/channel-test.go b/controller/channel-test.go index 686521ef..8c7e6f0d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,7 +14,7 @@ import ( "time" ) -func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { +func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: request.Model = "gpt-35-turbo" + defer func() { + if err != nil { + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") + } + }() default: request.Model = "gpt-3.5-turbo" } diff --git a/controller/channel.go b/controller/channel.go index 8afc0eed..50b2b5f6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") - channels := make([]model.Channel, 0) + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue diff --git a/controller/github.go b/controller/github.go index e1c64130..ee995379 100644 --- a/controller/github.go +++ b/controller/github.go @@ -79,6 +79,14 @@ func 
getGitHubUserInfoByCode(code string) (*GitHubUser, error) { func GitHubOAuth(c *gin.Context) { session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } username := session.Get("username") if username != nil { GitHubBind(c) @@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) { }) return } + +func GenerateOAuthCode(c *gin.Context) { + session := sessions.Default(c) + state := common.GetRandomString(12) + session.Set("oauth_state", state) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": state, + }) +} diff --git a/controller/log.go b/controller/log.go index df808d8d..b65867fe 100644 --- a/controller/log.go +++ b/controller/log.go @@ -2,6 +2,7 @@ package controller import ( "github.com/gin-gonic/gin" + "net/http" "one-api/common" "one-api/model" "strconv" @@ -21,17 +22,18 @@ func GetAllLogs(c *gin.Context) { channel, _ := strconv.Atoi(c.Query("channel")) logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetUserLogs(c *gin.Context) { @@ -47,34 +49,36 @@ func GetUserLogs(c *gin.Context) { modelName := c.Query("model_name") logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, 
"message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func SearchAllLogs(c *gin.Context) { keyword := c.Query("keyword") logs, err := model.SearchAllLogs(keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func SearchUserLogs(c *gin.Context) { @@ -82,17 +86,18 @@ func SearchUserLogs(c *gin.Context) { userId := c.GetInt("id") logs, err := model.SearchUserLogs(userId, keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetLogsStat(c *gin.Context) { @@ -105,7 +110,7 @@ func GetLogsStat(c *gin.Context) { channel, _ := strconv.Atoi(c.Query("channel")) quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -113,6 +118,7 @@ func GetLogsStat(c *gin.Context) { //"token": tokenNum, }, }) + return } func GetLogsSelfStat(c *gin.Context) { @@ -125,7 +131,7 @@ func GetLogsSelfStat(c *gin.Context) { channel, _ := strconv.Atoi(c.Query("channel")) quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -133,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) { //"token": tokenNum, }, }) + return +} + +func 
DeleteHistoryLogs(c *gin.Context) { + targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) + if targetTimestamp == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "target timestamp is required", + }) + return + } + count, err := model.DeleteOldLog(targetTimestamp) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": count, + }) + return } diff --git a/controller/model.go b/controller/model.go index 88f95f7b..637ebe10 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,15 @@ func init() { Root: "qwen-plus-v1", Parent: nil, }, + { + Id: "text-embedding-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "text-embedding-v1", + Parent: nil, + }, { Id: "SparkDesk", Object: "model", diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go new file mode 100644 index 00000000..d0159ce8 --- /dev/null +++ b/controller/relay-aiproxy.go @@ -0,0 +1,220 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strconv" + "strings" +) + +// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 + +type AIProxyLibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type AIProxyLibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type AIProxyLibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type AIProxyLibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []AIProxyLibraryDocument `json:"documents"` + AIProxyLibraryError +} + +type AIProxyLibraryStreamResponse struct { + Content string `json:"content"` + Finish bool 
`json:"finish"` + Model string `json:"model"` + Documents []AIProxyLibraryDocument `json:"documents"` +} + +func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { + query := "" + if len(request.Messages) != 0 { + query = request.Messages[len(request.Messages)-1].Content + } + return &AIProxyLibraryRequest{ + Model: request.Model, + Stream: request.Stream, + Query: query, + } +} + +func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { + if len(documents) == 0 { + return "" + } + content := "\n\n参考文档:\n" + for i, document := range documents { + content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) + } + return content +} + +func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { + content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: common.GetUUID(), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + } + return &fullTextResponse +} + +func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = aiProxyDocuments2Markdown(documents) + choice.FinishReason = &stopFinishReason + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = response.Content + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: 
"chat.completion.chunk", + Created: common.GetTimestamp(), + Model: response.Model, + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + var documents []AIProxyLibraryDocument + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var AIProxyLibraryResponse AIProxyLibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + response := documentsAIProxyLibrary(documents) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + 
string(jsonResponse)}) + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var AIProxyLibraryResponse AIProxyLibraryResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if AIProxyLibraryResponse.ErrCode != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: AIProxyLibraryResponse.Message, + Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), + Code: AIProxyLibraryResponse.ErrCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 9dca9a89..50dc743c 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -35,6 +35,29 @@ type AliChatRequest struct { Parameters AliParameters `json:"parameters,omitempty"` } +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string 
`json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + type AliError struct { Code string `json:"code"` Message string `json:"message"` @@ -44,6 +67,7 @@ type AliError struct { type AliUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` } type AliOutput struct { @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { } } +func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { + return &AliEmbeddingRequest{ + Model: "text-embedding-v1", + Input: struct { + Texts []string `json:"texts"` + }{ + Texts: request.ParseInput(), + }, + } +} + +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var aliResponse AliEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&aliResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Code != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", 
http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { + openAIEmbeddingResponse := OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + } + + for _, item := range response.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { choice := OpenAITextResponseChoice{ Index: 0, diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 1ebb751a..e6f54f01 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -92,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var audioResponse AudioResponse - defer func() { + defer func(ctx context.Context) { go func() { quota := countTokenText(audioResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota @@ -107,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := 
c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } }() - }() + }(c.Request.Context()) responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 39f31a9a..ed08ac04 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom } func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - baiduEmbeddingRequest := BaiduEmbeddingRequest{ - Input: nil, + return &BaiduEmbeddingRequest{ + Input: request.ParseInput(), } - switch request.Input.(type) { - case string: - baiduEmbeddingRequest.Input = []string{request.Input.(string)} - case []any: - for _, item := range request.Input.([]any) { - if str, ok := item.(string); ok { - baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) - } - } - } - return &baiduEmbeddingRequest } func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { diff --git a/controller/relay-image.go b/controller/relay-image.go index ea1e1897..fb30895c 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -125,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var textResponse ImageResponse - defer func() { + defer func(ctx context.Context) { if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { @@ -138,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, 
logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }() + }(c.Request.Context()) if consumeQuota { responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-text.go b/controller/relay-text.go index 43366da4..5a5f355b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -22,6 +23,7 @@ const ( APITypeZhipu APITypeAli APITypeXunfei + APITypeAIProxyLibrary ) var httpClient *http.Client @@ -105,6 +107,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAli case common.ChannelTypeXunfei: apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -172,6 +176,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + case APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int var completionTokens int @@ -203,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } if consumeQuota && preConsumedQuota > 0 { err := 
model.PreConsumeTokenQuota(tokenId, preConsumedQuota) @@ -258,8 +268,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) case APITypeAli: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err := json.Marshal(aliRequest) + var jsonStr []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + case APITypeAIProxyLibrary: + aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -303,6 +329,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + default: + req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -322,6 +350,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(c.Request.Context()) + } return relayErrorHandler(resp) } } @@ 
-329,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { var textResponse TextResponse tokenName := c.GetString("token_name") - defer func() { + defer func(ctx context.Context) { // c.Writer.Flush() go func() { if consumeQuota { @@ -352,22 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + common.LogError(ctx, "error update user quota cache: "+err.Error()) } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) } } }() - }() + }(c.Request.Context()) switch apiType { case APITypeOpenAI: if isStream { @@ -488,7 +524,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - err, usage := aliHandler(c, resp) + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + err, usage = aliHandler(c, resp) + } if err != nil { return err } @@ -498,14 +541,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil } case APITypeXunfei: + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := 
strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + var err *OpenAIErrorWithStatusCode + var usage *Usage if isStream { - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) + } + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + case APITypeAIProxyLibrary: + if isStream { + err, usage := aiProxyLibraryStreamHandler(c, resp) if err != nil { return err } @@ -514,7 +572,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + err, usage := aiProxyLibraryHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 9010d275..3d5948fc 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "one_api_error", + Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), }, diff --git 
a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 3b6fe5a0..ff6bf065 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, + FinishReason: stopFinishReason, } fullTextResponse := OpenAITextResponse{ Object: "chat.completion", @@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { } func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + } + setEventStreamHeaders(c) var usage Usage - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) 
(*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - if apiVersion == "" { - apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + var usage Usage + var content string + var xunfeiResponse XunfeiChatResponse + stop := false + for !stop { + select { + case xunfeiResponse = <-dataChan: + content += xunfeiResponse.Payload.Choices.Text[0].Content + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + case stop = <-stopChan: + } } - domain := "general" - if apiVersion == "v2.1" { - domain = "generalv2" + + xunfeiResponse.Payload.Choices.Text[0].Content = content + + response := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) + c.Writer.Header().Set("Content-Type", "application/json") + _, _ = c.Writer.Write(jsonResponse) + return nil, &usage +} + +func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { - return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + return nil, nil, err } data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { - return 
errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + return nil, nil, err } + dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { @@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } stopChan <- true }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - return nil, &usage + + return dataChan, stopChan, nil } -func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var xunfeiResponse XunfeiChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) } - err = json.Unmarshal(responseBody, &xunfeiResponse) - if err != nil { - return errorWrapper(err, 
"unmarshal_response_body_failed", http.StatusInternalServerError), nil + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" } - if xunfeiResponse.Header.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: xunfeiResponse.Header.Message, - Type: "xunfei_error", - Param: "", - Code: xunfeiResponse.Header.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + return domain, authUrl } diff --git a/controller/relay.go b/controller/relay.go index 056d42d3..1926110e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct { Functions any `json:"functions,omitempty"` } +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -177,6 +196,7 @@ func Relay(c *gin.Context) { err = relayTextHelper(c, relayMode) } if err != nil { + requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { @@ -188,12 +208,13 @@ func Relay(c *gin.Context) { if 
err.StatusCode == http.StatusTooManyRequests { err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, }) } channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") diff --git a/i18n/en.json b/i18n/en.json index aed65979..9b2ca4c8 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -523,5 +523,6 @@ "按照如下格式输入:": "Enter in the following format:", "模型版本": "Model version", "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", - "点击查看": "click to view" + "点击查看": "click to view", + "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" 
} diff --git a/main.go b/main.go index 9fb0a73e..e8ef4c20 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,7 @@ var buildFS embed.FS var indexPage []byte func main() { - common.SetupGinLog() + common.SetupLogger() common.SysLog("One API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) @@ -77,14 +77,20 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + common.BatchUpdateEnabled = true + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } controller.InitTokenEncoders() // Initialize HTTP server - server := gin.Default() + server := gin.New() + server.Use(gin.Recovery()) // This will cause SSE not to work!!! //server.Use(gzip.Gzip(gzip.DefaultCompression)) - server.Use(middleware.CORS()) - + server.Use(middleware.RequestId()) + middleware.SetUpLogger(server) // Initialize session store store := cookie.NewStore([]byte(common.SessionSecret)) server.Use(sessions.Sessions("session", store)) diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..dfbc7dbd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } - if !model.CacheIsUserEnabled(token.UserId) { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "用户已被封禁", - "type": "one_api_error", - }, - }) - c.Abort() + userEnabled, err := model.IsUserEnabled(token.UserId) + if err != nil { + abortWithMessage(c, http.StatusInternalServerError, err.Error()) + return + } + if !userEnabled { + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", 
token.UserId) @@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "普通用户不支持指定渠道", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 93827c95..ab374a85 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } channel, err = model.GetChannelById(id, true) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } if channel.Status != common.ChannelStatusEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "该渠道已被禁用", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { @@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的请求", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的请求") return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "message": message, - "type": "one_api_error", - }, - 
}) - c.Abort() + abortWithMessage(c, http.StatusServiceUnavailable, message) return } } @@ -115,8 +85,13 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { + switch channel.Type { + case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + case common.ChannelTypeAIProxyLibrary: + c.Set("library_id", channel.Other) } c.Next() } diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 00000000..02f2e0a9 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func SetUpLogger(server *gin.Engine) { + server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { + var requestID string + if param.Keys != nil { + requestID = param.Keys[common.RequestIdKey].(string) + } + return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + requestID, + param.StatusCode, + param.Latency, + param.ClientIP, + param.Method, + param.Path, + ) + })) +} diff --git a/middleware/request-id.go b/middleware/request-id.go new file mode 100644 index 00000000..e623be7a --- /dev/null +++ b/middleware/request-id.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "context" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func RequestId() func(c *gin.Context) { + return func(c *gin.Context) { + id := common.GetTimeString() + common.GetRandomString(8) + c.Set(common.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + c.Request = c.Request.WithContext(ctx) + c.Header(common.RequestIdKey, id) + c.Next() + } +} diff --git a/middleware/utils.go 
b/middleware/utils.go new file mode 100644 index 00000000..536125cc --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "one-api/common" +) + +func abortWithMessage(c *gin.Context, statusCode int, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() + common.LogError(c.Request.Context(), message) +} diff --git a/model/cache.go b/model/cache.go index 631ef49a..1b547842 100644 --- a/model/cache.go +++ b/model/cache.go @@ -104,23 +104,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { return err } -func CacheIsUserEnabled(userId int) bool { +func CacheIsUserEnabled(userId int) (bool, error) { if !common.RedisEnabled { return IsUserEnabled(userId) } enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err != nil { - status := common.UserStatusDisabled - if IsUserEnabled(userId) { - status = common.UserStatusEnabled - } - enabled = fmt.Sprintf("%d", status) - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } + if err == nil { + return enabled == "1", nil } - return enabled == "1" + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err } var group2model2channels map[string]map[string][]*Channel diff --git a/model/channel.go b/model/channel.go index 2da210fa..d146193b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -142,6 +142,14 @@ func 
UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/log.go b/model/log.go index 53e21d22..1c0a2dc6 100644 --- a/model/log.go +++ b/model/log.go @@ -1,6 +1,8 @@ package model import ( + "context" + "fmt" "gorm.io/gorm" "one-api/common" ) @@ -45,7 +47,9 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -64,7 +68,7 @@ func RecordConsumeLog(userId int, channelId int, promptTokens int, completionTok } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + common.LogError(ctx, "failed to record log: "+err.Error()) } } @@ -174,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa tx.Where("type = ?", LogTypeConsume).Scan(&token) return token } + +func DeleteOldLog(targetTimestamp int64) (int64, error) { + result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + return 
result.RowsAffected, result.Error +} diff --git a/model/token.go b/model/token.go index 7cd226c6..0fa984d3 100644 --- a/model/token.go +++ b/model/token.go @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err == nil { + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌额度已用尽") } - go func() { - token.AccessedTime = common.GetTimestamp() - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token" + err.Error()) - } - }() return token, nil } return nil, errors.New("无效的令牌") @@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return 
increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota + ?", quota), - "used_quota": gorm.Expr("used_quota - ?", quota), + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err @@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota - ?", quota), - "used_quota": gorm.Expr("used_quota + ?", quota), + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 7c771840..cee4b023 100644 --- a/model/user.go +++ b/model/user.go @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) bool { +func IsUserEnabled(userId int) (bool, error) { if userId == 0 { - return false + return false, errors.New("user id is empty") } var user User err := DB.Where("id = ?", userId).Select("status").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) - return false + return false, err } - return user.Status == common.UserStatusEnabled + return user.Status == common.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + 
if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } @@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), - "request_count": gorm.Expr("request_count + ?", 1), + "request_count": gorm.Expr("request_count + ?", count), }, ).Error if err != nil { diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 00000000..61734332 --- /dev/null +++ b/model/utils.go @@ -0,0 +1,75 @@ +package model + +import ( + "one-api/common" + "sync" + "time" +) + +const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeChannelUsedQuota +) + +var batchUpdateStores []map[int]int +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + 
batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func addNewRecord(type_ int, id int, value int) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + common.SysLog("batch update started") + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int) + batchUpdateLocks[i].Unlock() + + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + common.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + common.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuotaAndRequestCount: + updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + common.SysLog("batch update finished") +} diff --git a/router/api-router.go b/router/api-router.go index cc330d7e..d12bc54b 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) + apiRouter.GET("/oauth/state", 
middleware.CriticalRateLimit(), controller.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) @@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) { } logRoute := apiRouter.Group("/log") logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) + logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) diff --git a/router/relay-router.go b/router/relay-router.go index a76e42cf..e84f02db 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -8,6 +8,7 @@ import ( ) func SetRelayRouter(router *gin.Engine) { + router.Use(middleware.CORS()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 43262e1a..7c8457d0 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; +import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; @@ -202,6 +202,7 @@ const ChannelsTable = () => { 
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); + showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") } }; diff --git a/web/src/components/GitHubOAuth.js b/web/src/components/GitHubOAuth.js index 147d4d30..c43ed2a1 100644 --- a/web/src/components/GitHubOAuth.js +++ b/web/src/components/GitHubOAuth.js @@ -13,8 +13,8 @@ const GitHubOAuth = () => { let navigate = useNavigate(); - const sendCode = async (code, count) => { - const res = await API.get(`/api/oauth/github?code=${code}`); + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); const { success, message, data } = res.data; if (success) { if (message === 'bind') { @@ -36,13 +36,14 @@ const GitHubOAuth = () => { count++; setPrompt(`出现错误,第 ${count} 次重试中...`); await new Promise((resolve) => setTimeout(resolve, count * 2000)); - await sendCode(code, count); + await sendCode(code, state, count); } }; useEffect(() => { let code = searchParams.get('code'); - sendCode(code, 0).then(); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); }, []); return ( diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 110dad46..b5c4e6f9 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; import { API, getLogo, showError, showSuccess } from '../helpers'; +import { getOAuthState, onGitHubOAuthClicked } from './utils'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -31,12 +32,6 @@ const LoginForm = () => { const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); - const onGitHubOAuthClicked = () => { - window.open( - 
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` - ); - }; - const onWeChatLoginClicked = () => { setShowWeChatLoginModal(true); }; @@ -131,7 +126,7 @@ const LoginForm = () => { circular color='black' icon='github' - onClick={onGitHubOAuthClicked} + onClick={()=>onGitHubOAuthClicked(status.github_client_id)} /> ) : ( <> diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index f3741354..e266d79a 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -342,7 +342,7 @@ const LogsTable = () => { .map((log, idx) => { if (log.deleted) return <>; return ( - + {renderTimestamp(log.created_at)} { isAdminUser && ( diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 2adc7fa4..bf8b5ffd 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -1,8 +1,9 @@ import React, { useEffect, useState } from 'react'; import { Divider, Form, Grid, Header } from 'semantic-ui-react'; -import { API, showError, verifyJSON } from '../helpers'; +import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; const OperationSetting = () => { + let now = new Date(); let [inputs, setInputs] = useState({ QuotaForNewUser: 0, QuotaForInviter: 0, @@ -20,10 +21,11 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', - RetryTimes: 0, + RetryTimes: 0 }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); + let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago const getOptions = async () => { const res = await API.get('/api/option/'); @@ -130,6 +132,17 @@ const OperationSetting = () => { } }; + const deleteHistoryLogs = async () => { + console.log(inputs); + const res = await 
API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`${data} 条日志已清理!`); + return; + } + showError('日志清理失败:' + message); + }; + return ( @@ -179,12 +192,6 @@ const OperationSetting = () => { /> - { submitConfig('general').then(); }}>保存通用设置 +
+ 日志设置 +
+ + + + + { + setHistoryTimestamp(value); + }} /> + + { + deleteHistoryLogs().then(); + }}>清理历史日志 +
监控设置
diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index c7a303f9..6baf1f35 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom'; import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; import Turnstile from 'react-turnstile'; import { UserContext } from '../context/User'; +import { onGitHubOAuthClicked } from './utils'; const PersonalSetting = () => { const [userState, userDispatch] = useContext(UserContext); @@ -130,12 +131,6 @@ const PersonalSetting = () => { } }; - const openGitHubOAuth = () => { - window.open( - `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` - ); - }; - const sendVerificationCode = async () => { setDisableButton(true); if (inputs.email === '') return; @@ -249,7 +244,7 @@ const PersonalSetting = () => { { status.github_oauth && ( - + ) } { - if (customModel.trim() === '') return; - if (inputs.models.includes(customModel)) return; - let localModels = [...inputs.models]; - localModels.push(customModel); - let localModelOptions = []; - localModelOptions.push({ - key: customModel, - text: customModel, - value: customModel - }); - setModelOptions(modelOptions => { - return [...modelOptions, ...localModelOptions]; - }); - setCustomModel(''); - handleInputChange(null, { name: 'models', value: localModels }); - }}>填入 + } placeholder='输入自定义模型名称' value={customModel} onChange={(e, { value }) => { setCustomModel(value); }} + onKeyDown={(e) => { + if (e.key === 'Enter') { + addCustomModel(); + e.preventDefault(); + } + }} /> @@ -375,7 +411,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? 
'按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={type2secretPrompt(inputs.type)} onChange={handleInputChange} value={inputs.key} autoComplete='new-password' @@ -393,7 +429,7 @@ const EditChannel = () => { ) } { - inputs.type !== 3 && inputs.type !== 8 && ( + inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( { ) } + { + inputs.type === 22 && ( + + + + ) + }