diff --git a/README.md b/README.md index 45c8b603..b89c6be8 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 + 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 @@ -275,8 +277,9 @@ graph LR 不加的话将会使用负载均衡的方式使用多个渠道。 ### 环境变量 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + 例子:`SESSION_SECRET=random_string` 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 @@ -303,6 +306,14 @@ graph LR + 例子:`CHANNEL_TEST_FREQUENCY=1440` 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` +10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` + + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 +11. `BATCH_UPDATE_INTERVAL`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` +12. 请求频率限制: + + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 IP 三分钟内的最大请求数,默认为 `180`。 + + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 IP 三分钟内的最大请求数,默认为 `60`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 @@ -339,6 +350,7 @@ https://openai.justsong.cn 5. ChatGPT Next Web 报错:`Failed to fetch` + 部署的时候不要设置 `BASE_URL`。 + 检查你的接口地址和 API Key 有没有填对。 + + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 
报错:`当前分组负载已饱和,请稍后再试` + 上游通道 429 了。 diff --git a/common/constants.go b/common/constants.go index e5211e3d..69bd12a8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var BatchUpdateEnabled = false +var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + const ( RoleGuestUser = 0 RoleCommonUser = 1 @@ -111,10 +114,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = 180 + GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = 60 + GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -154,49 +157,53 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + 
ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 686521ef..8c7e6f0d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,7 +14,7 @@ import ( "time" ) -func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { +func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr return 
errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: request.Model = "gpt-35-turbo" + defer func() { + if err != nil { + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") + } + }() default: request.Model = "gpt-3.5-turbo" } diff --git a/controller/channel.go b/controller/channel.go index 8afc0eed..50b2b5f6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") - channels := make([]model.Channel, 0) + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go new file mode 100644 index 00000000..d0159ce8 --- /dev/null +++ b/controller/relay-aiproxy.go @@ -0,0 +1,220 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strconv" + "strings" +) + +// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 + +type AIProxyLibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type AIProxyLibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type AIProxyLibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type AIProxyLibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []AIProxyLibraryDocument `json:"documents"` + AIProxyLibraryError +} + +type AIProxyLibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []AIProxyLibraryDocument `json:"documents"` +} + +func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { + query := "" + if len(request.Messages) != 0 { + 
query = request.Messages[len(request.Messages)-1].Content + } + return &AIProxyLibraryRequest{ + Model: request.Model, + Stream: request.Stream, + Query: query, + } +} + +func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { + if len(documents) == 0 { + return "" + } + content := "\n\n参考文档:\n" + for i, document := range documents { + content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) + } + return content +} + +func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { + content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: common.GetUUID(), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + } + return &fullTextResponse +} + +func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = aiProxyDocuments2Markdown(documents) + choice.FinishReason = &stopFinishReason + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = response.Content + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: response.Model, + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var 
usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + var documents []AIProxyLibraryDocument + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var AIProxyLibraryResponse AIProxyLibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + response := documentsAIProxyLibrary(documents) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + 
+func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var AIProxyLibraryResponse AIProxyLibraryResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if AIProxyLibraryResponse.ErrCode != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: AIProxyLibraryResponse.Message, + Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), + Code: AIProxyLibraryResponse.ErrCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 708f94cb..b190a999 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -22,6 +22,7 @@ const ( APITypeZhipu APITypeAli APITypeXunfei + APITypeAIProxyLibrary ) var httpClient *http.Client @@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAli case common.ChannelTypeXunfei: apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -174,6 +177,8 @@ 
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeEmbeddings { fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } + case APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int var completionTokens int @@ -274,6 +279,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeAIProxyLibrary: + aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } var req *http.Request @@ -313,6 +326,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + default: + req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -373,7 +388,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) } } @@ -534,6 +548,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } else { return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) } + case 
APITypeAIProxyLibrary: + if isStream { + err, usage := aiProxyLibraryStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + err, usage := aiProxyLibraryHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/i18n/en.json b/i18n/en.json index aed65979..9b2ca4c8 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -523,5 +523,6 @@ "按照如下格式输入:": "Enter in the following format:", "模型版本": "Model version", "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", - "点击查看": "click to view" + "点击查看": "click to view", + "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" 
} diff --git a/main.go b/main.go index 9fb0a73e..8c5f2f31 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + common.BatchUpdateEnabled = true + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } controller.InitTokenEncoders() // Initialize HTTP server diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..95516d6e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) { c.Abort() return } - if !model.CacheIsUserEnabled(token.UserId) { + userEnabled, err := model.IsUserEnabled(token.UserId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if !userEnabled { c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "message": "用户已被封禁", diff --git a/middleware/distributor.go b/middleware/distributor.go index 93827c95..e8b76596 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { + switch channel.Type { + case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + case common.ChannelTypeAIProxyLibrary: + c.Set("library_id", channel.Other) } c.Next() } diff --git a/model/cache.go b/model/cache.go index 55fbba9b..c28952b5 100644 --- a/model/cache.go +++ b/model/cache.go @@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { return err } -func 
CacheIsUserEnabled(userId int) bool { +func CacheIsUserEnabled(userId int) (bool, error) { if !common.RedisEnabled { return IsUserEnabled(userId) } enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err != nil { - status := common.UserStatusDisabled - if IsUserEnabled(userId) { - status = common.UserStatusEnabled - } - enabled = fmt.Sprintf("%d", status) - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } + if err == nil { + return enabled == "1", nil } - return enabled == "1" + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err } var group2model2channels map[string]map[string][]*Channel diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..5c495bab 100644 --- a/model/channel.go +++ b/model/channel.go @@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/token.go b/model/token.go index 7cd226c6..0fa984d3 100644 --- a/model/token.go +++ b/model/token.go @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) 
if err == nil { + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌额度已用尽") } - go func() { - token.AccessedTime = common.GetTimestamp() - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token" + err.Error()) - } - }() return token, nil } return nil, errors.New("无效的令牌") @@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota + ?", quota), - "used_quota": gorm.Expr("used_quota - ?", quota), + 
"remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err @@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota - ?", quota), - "used_quota": gorm.Expr("used_quota + ?", quota), + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 7c771840..cee4b023 100644 --- a/model/user.go +++ b/model/user.go @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) bool { +func IsUserEnabled(userId int) (bool, error) { if userId == 0 { - return false + return false, errors.New("user id is empty") } var user User err := DB.Where("id = ?", userId).Select("status").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) - return false + return false, err } - return user.Status == common.UserStatusEnabled + return user.Status == common.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + 
?", quota)).Error return err } @@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), - "request_count": gorm.Expr("request_count + ?", 1), + "request_count": gorm.Expr("request_count + ?", count), }, ).Error if err != nil { diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 00000000..61734332 --- /dev/null +++ b/model/utils.go @@ -0,0 +1,75 @@ +package model + +import ( + "one-api/common" + "sync" + "time" +) + +const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeChannelUsedQuota +) + +var batchUpdateStores []map[int]int +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func 
addNewRecord(type_ int, id int, value int) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + common.SysLog("batch update started") + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int) + batchUpdateLocks[i].Unlock() + + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + common.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + common.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuotaAndRequestCount: + updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + common.SysLog("batch update finished") +} diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index bacb7689..c981e261 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -324,7 +324,7 @@ const LogsTable = () => { .map((log, idx) => { if (log.deleted) return <>; return ( - + {renderTimestamp(log.created_at)} { isAdminUser && ( diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index b1631479..e42afc6e 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,8 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, + { 
key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 028000d6..78ff1952 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k' }; +function type2secretPrompt(type) { + // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') + switch (type) { + case 15: + return '按照如下格式输入:APIKey|SecretKey'; + case 18: + return '按照如下格式输入:APPID|APISecret|APIKey'; + case 22: + return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; + default: + return '请输入渠道对应的鉴权密钥'; + } +} + const EditChannel = () => { const params = useParams(); const navigate = useNavigate(); @@ -193,6 +207,24 @@ const EditChannel = () => { } }; + const addCustomModel = () => { + if (customModel.trim() === '') return; + if (inputs.models.includes(customModel)) return; + let localModels = [...inputs.models]; + localModels.push(customModel); + let localModelOptions = []; + localModelOptions.push({ + key: customModel, + text: customModel, + value: customModel + }); + setModelOptions(modelOptions => { + return [...modelOptions, ...localModelOptions]; + }); + setCustomModel(''); + handleInputChange(null, { name: 'models', value: localModels }); + }; + return ( <> @@ -295,6 +327,20 @@ const EditChannel = () => { ) } + { + inputs.type === 21 && ( + + + + ) + } { }}>清除所有模型 { - if (customModel.trim() === '') return; - if (inputs.models.includes(customModel)) return; - let localModels = [...inputs.models]; - localModels.push(customModel); - let localModelOptions = []; - localModelOptions.push({ - key: customModel, - text: 
customModel, - value: customModel - }); - setModelOptions(modelOptions => { - return [...modelOptions, ...localModelOptions]; - }); - setCustomModel(''); - handleInputChange(null, { name: 'models', value: localModels }); - }}>填入 + } placeholder='输入自定义模型名称' value={customModel} onChange={(e, { value }) => { setCustomModel(value); }} + onKeyDown={(e) => { + if (e.key === 'Enter') { + addCustomModel(); + e.preventDefault(); + } + }} /> @@ -375,7 +411,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={type2secretPrompt(inputs.type)} onChange={handleInputChange} value={inputs.key} autoComplete='new-password' @@ -393,7 +429,7 @@ const EditChannel = () => { ) } { - inputs.type !== 3 && inputs.type !== 8 && ( + inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( { ) } + { + inputs.type === 22 && ( + + + + ) + }