From ef2c5abb5b3d1fc5bffdb358758156e790db38f8 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 30 Aug 2023 20:51:37 +0800 Subject: [PATCH 01/14] docs: update README --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 45c8b603..a0f3bcb9 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 +如果启动失败,请添加 `--privileged=true`,具体参考 #482。 + 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 @@ -275,8 +277,9 @@ graph LR 不加的话将会使用负载均衡的方式使用多个渠道。 ### 环境变量 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + 例子:`SESSION_SECRET=random_string` 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 @@ -339,6 +342,7 @@ https://openai.justsong.cn 5. ChatGPT Next Web 报错:`Failed to fetch` + 部署的时候不要设置 `BASE_URL`。 + 检查你的接口地址和 API Key 有没有填对。 + + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 报错:`当前分组负载已饱和,请稍后再试` + 上游通道 429 了。 From abbf2fded0a694390e0f025b63e757d23b86155c Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 30 Aug 2023 21:15:56 +0800 Subject: [PATCH 02/14] perf: preallocate array capacity --- controller/channel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/channel.go b/controller/channel.go index 8afc0eed..50b2b5f6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") - channels := make([]model.Channel, 0) + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue From f0d5e102a3dc22289e00860f63ebd0f125531641 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 30 Aug 2023 21:43:01 +0800 Subject: [PATCH 03/14] fix: fix log table use created_at as key instead of id Co-authored-by: 13714733197 <13714733197@163.com> --- web/src/components/LogsTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index bacb7689..c981e261 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -324,7 +324,7 @@ const LogsTable = () => { .map((log, idx) => { if (log.deleted) return <>; return ( - + {renderTimestamp(log.created_at)} { isAdminUser && ( From 04acdb1ccb059d7dd86f4f397f4a027259edc74d Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 12:51:59 +0800 Subject: [PATCH 04/14] feat: support aiproxy's library --- common/constants.go | 44 ++--- controller/relay-aiproxy.go | 220 +++++++++++++++++++++++++ controller/relay-text.go | 35 ++++ middleware/distributor.go | 7 +- web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 14 ++ 6 files changed, 299 insertions(+), 22 deletions(-) create mode 100644 controller/relay-aiproxy.go diff --git a/common/constants.go b/common/constants.go index e5211e3d..66ca06f4 100644 --- a/common/constants.go +++ b/common/constants.go @@ -154,27 +154,28 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + ChannelTypeAIProxyLibrary = 21 ) var ChannelBaseURLs = []string{ @@ -199,4 +200,5 @@ var ChannelBaseURLs = []string{ "", // 18 "https://ai.360.cn", // 19 "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 } diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go new file mode 100644 index 00000000..d0159ce8 --- /dev/null +++ b/controller/relay-aiproxy.go @@ -0,0 +1,220 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strconv" + "strings" +) + +// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 + +type AIProxyLibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type AIProxyLibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type AIProxyLibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type AIProxyLibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []AIProxyLibraryDocument `json:"documents"` + AIProxyLibraryError +} + +type AIProxyLibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []AIProxyLibraryDocument `json:"documents"` +} + +func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { + query := "" + if len(request.Messages) != 0 { + query = request.Messages[len(request.Messages)-1].Content + } + return &AIProxyLibraryRequest{ + Model: request.Model, + Stream: request.Stream, + Query: query, + } +} + +func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { + if len(documents) == 0 { + return "" + } + content := "\n\n参考文档:\n" + for i, document := range documents { + content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) + } + return content +} + +func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { + content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: common.GetUUID(), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + } + return &fullTextResponse +} + +func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = aiProxyDocuments2Markdown(documents) + choice.FinishReason = &stopFinishReason + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = response.Content + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: response.Model, + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + var documents []AIProxyLibraryDocument + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var AIProxyLibraryResponse AIProxyLibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + response := documentsAIProxyLibrary(documents) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var AIProxyLibraryResponse AIProxyLibraryResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if AIProxyLibraryResponse.ErrCode != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: AIProxyLibraryResponse.Message, + Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), + Code: AIProxyLibraryResponse.ErrCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 624b9d01..6f410f96 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -22,6 +22,7 @@ const ( APITypeZhipu APITypeAli APITypeXunfei + APITypeAIProxyLibrary ) var httpClient *http.Client @@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAli case common.ChannelTypeXunfei: apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -171,6 +174,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + case APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int var completionTokens int @@ -263,6 +268,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeAIProxyLibrary: + aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } var req *http.Request @@ -302,6 +315,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + default: + req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -516,6 +531,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } else { return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) } + case APITypeAIProxyLibrary: + if isStream { + err, usage := aiProxyLibraryStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + err, usage := aiProxyLibraryHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 93827c95..e8b76596 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { + switch channel.Type { + case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + case common.ChannelTypeAIProxyLibrary: + c.Set("library_id", channel.Other) } c.Next() } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index b1631479..1eecfc5a 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index da11b588..7e150b4a 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -295,6 +295,20 @@ const EditChannel = () => { ) } + { + inputs.type === 21 && ( + + + + ) + } Date: Sun, 3 Sep 2023 14:58:20 +0800 Subject: [PATCH 05/14] feat: add batch update support (close #414) --- README.md | 4 +++ common/constants.go | 3 ++ main.go | 5 +++ model/channel.go | 8 +++++ model/token.go | 16 ++++++++++ model/user.go | 26 +++++++++++++++- model/utils.go | 75 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 model/utils.go diff --git a/README.md b/README.md index a0f3bcb9..a2105df2 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,10 @@ graph LR + 例子:`CHANNEL_TEST_FREQUENCY=1440` 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` +10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` +11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 66ca06f4..b272fbe6 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var BatchUpdateEnabled = false +var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/main.go b/main.go index 9fb0a73e..8c5f2f31 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + common.BatchUpdateEnabled = true + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } controller.InitTokenEncoders() // Initialize HTTP server diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..5c495bab 100644 --- a/model/channel.go +++ b/model/channel.go @@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/token.go b/model/token.go index 7cd226c6..dfda27e3 100644 --- a/model/token.go +++ b/model/token.go @@ -131,6 +131,14 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -144,6 +152,14 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), diff --git a/model/user.go b/model/user.go index 7c771840..67511267 100644 --- a/model/user.go +++ b/model/user.go @@ -275,6 +275,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } @@ -283,6 +291,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,10 +309,18 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), - "request_count": gorm.Expr("request_count + ?", 1), + "request_count": gorm.Expr("request_count + ?", count), }, ).Error if err != nil { diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 00000000..61734332 --- /dev/null +++ b/model/utils.go @@ -0,0 +1,75 @@ +package model + +import ( + "one-api/common" + "sync" + "time" +) + +const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeChannelUsedQuota +) + +var batchUpdateStores []map[int]int +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func addNewRecord(type_ int, id int, value int) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + common.SysLog("batch update started") + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int) + batchUpdateLocks[i].Unlock() + + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + common.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + common.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuotaAndRequestCount: + updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + common.SysLog("batch update finished") +} From 9db93316c4bd86c66fb74d5e7c36eaf4f3fde697 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 15:12:54 +0800 Subject: [PATCH 06/14] docs: update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a2105df2..382dfb35 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,7 @@ graph LR + 例子:`POLLING_INTERVAL=5` 10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` From 7e575abb951475c95adfe32d82994b81d50a7094 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 15:50:49 +0800 Subject: [PATCH 07/14] feat: add channel type FastGPT --- common/constants.go | 46 ++++++++++++++------------ web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 32 ++++++++++++++++-- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/common/constants.go b/common/constants.go index b272fbe6..4a3f3f2b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -179,29 +179,31 @@ const ( ChannelType360 = 19 ChannelTypeOpenRouter = 20 ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 1eecfc5a..e42afc6e 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7e150b4a..d75e67eb 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k' }; +function type2secretPrompt(type) { + // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') + switch (type) { + case 15: + return "按照如下格式输入:APIKey|SecretKey" + case 18: + return "按照如下格式输入:APPID|APISecret|APIKey" + case 22: + return "按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041" + default: + return "请输入渠道对应的鉴权密钥" + } +} + const EditChannel = () => { const params = useParams(); const navigate = useNavigate(); @@ -389,7 +403,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={type2secretPrompt(inputs.type)} onChange={handleInputChange} value={inputs.key} autoComplete='new-password' @@ -407,7 +421,7 @@ const EditChannel = () => { ) } { - inputs.type !== 3 && inputs.type !== 8 && ( + inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( { ) } + { + inputs.type === 22 && ( + + + + ) + } From 621eb91b46a1909bf439cc50688bcdbc689e28b2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:31:58 +0800 Subject: [PATCH 08/14] chore: pass through error out --- controller/relay-text.go | 1 - middleware/auth.go | 13 +++++++++++- model/cache.go | 29 ++++++++++++++++----------- model/token.go | 43 ++++++++++++++++++++++------------------ model/user.go | 9 ++++----- 5 files changed, 57 insertions(+), 38 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 6f410f96..c6659799 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -377,7 +377,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) } } diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..95516d6e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) { c.Abort() return } - if !model.CacheIsUserEnabled(token.UserId) { + userEnabled, err := model.IsUserEnabled(token.UserId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if !userEnabled { c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "message": "用户已被封禁", diff --git a/model/cache.go b/model/cache.go index 55fbba9b..c28952b5 100644 --- a/model/cache.go +++ b/model/cache.go @@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { return err } -func CacheIsUserEnabled(userId int) bool { +func CacheIsUserEnabled(userId int) (bool, error) { if !common.RedisEnabled { return IsUserEnabled(userId) } enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err != nil { - status := common.UserStatusDisabled - if IsUserEnabled(userId) { - status = common.UserStatusEnabled - } - enabled = fmt.Sprintf("%d", status) - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } + if err == nil { + return enabled == "1", nil } - return enabled == "1" + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err } var group2model2channels map[string]map[string][]*Channel diff --git a/model/token.go b/model/token.go index dfda27e3..0fa984d3 100644 --- a/model/token.go +++ b/model/token.go @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err == nil { + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌额度已用尽") } - go func() { - token.AccessedTime = common.GetTimestamp() - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token" + err.Error()) - } - }() return token, nil } return nil, errors.New("无效的令牌") @@ -141,8 +144,9 @@ func IncreaseTokenQuota(id int, quota int) (err error) { func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota + ?", quota), - "used_quota": gorm.Expr("used_quota - ?", quota), + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err @@ -162,8 +166,9 @@ func DecreaseTokenQuota(id int, quota int) (err error) { func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota - ?", quota), - "used_quota": gorm.Expr("used_quota + ?", quota), + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 67511267..cee4b023 100644 --- a/model/user.go +++ b/model/user.go @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) bool { +func IsUserEnabled(userId int) (bool, error) { if userId == 0 { - return false + return false, errors.New("user id is empty") } var user User err := DB.Where("id = ?", userId).Select("status").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) - return false + return false, err } - return user.Status == common.UserStatusEnabled + return user.Status == common.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { From 276163affdef85d5adf9044fdbe0ab1b32455fcd Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:40:58 +0800 Subject: [PATCH 09/14] fix: press enter to submit custom model name --- web/src/pages/Channel/EditChannel.js | 50 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index d75e67eb..2ad3dc6a 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -14,13 +14,13 @@ function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { case 15: - return "按照如下格式输入:APIKey|SecretKey" + return '按照如下格式输入:APIKey|SecretKey'; case 18: - return "按照如下格式输入:APPID|APISecret|APIKey" + return '按照如下格式输入:APPID|APISecret|APIKey'; case 22: - return "按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041" + return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; default: - return "请输入渠道对应的鉴权密钥" + return '请输入渠道对应的鉴权密钥'; } } @@ -207,6 +207,24 @@ const EditChannel = () => { } }; + const addCustomModel = () => { + if (customModel.trim() === '') return; + if (inputs.models.includes(customModel)) return; + let localModels = [...inputs.models]; + localModels.push(customModel); + let localModelOptions = []; + localModelOptions.push({ + key: customModel, + text: customModel, + value: customModel + }); + setModelOptions(modelOptions => { + return [...modelOptions, ...localModelOptions]; + }); + setCustomModel(''); + handleInputChange(null, { name: 'models', value: localModels }); + }; + return ( <> @@ -350,29 +368,19 @@ const EditChannel = () => { }}>清除所有模型 { - if (customModel.trim() === '') return; - if (inputs.models.includes(customModel)) return; - let localModels = [...inputs.models]; - localModels.push(customModel); - let localModelOptions = []; - localModelOptions.push({ - key: customModel, - text: customModel, - value: customModel - }); - setModelOptions(modelOptions => { - return [...modelOptions, ...localModelOptions]; - }); - setCustomModel(''); - handleInputChange(null, { name: 'models', value: localModels }); - }}>填入 + } placeholder='输入自定义模型名称' value={customModel} onChange={(e, { value }) => { setCustomModel(value); }} + onKeyDown={(e) => { + if (e.key === 'Enter') { + addCustomModel(); + e.preventDefault(); + } + }} /> From a721a5b6f9d4a9ce721ce323c5eeb0be34a3b97b Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:46:07 +0800 Subject: [PATCH 10/14] chore: add error prompt for Azure --- controller/channel-test.go | 7 ++++++- i18n/en.json | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 686521ef..8c7e6f0d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,7 +14,7 @@ import ( "time" ) -func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { +func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: request.Model = "gpt-35-turbo" + defer func() { + if err != nil { + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") + } + }() default: request.Model = "gpt-3.5-turbo" } diff --git a/i18n/en.json b/i18n/en.json index aed65979..9b2ca4c8 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -523,5 +523,6 @@ "按照如下格式输入:": "Enter in the following format:", "模型版本": "Model version", "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", - "点击查看": "click to view" + "点击查看": "click to view", + "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" } From 0f949c3782f302b41b07475696fdf63d6cdb99ff Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:49:41 +0800 Subject: [PATCH 11/14] docs: update README (close #482) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 382dfb35..92eb7f6a 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 -如果启动失败,请添加 `--privileged=true`,具体参考 #482。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482。 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 From c55bb6781824f384f940e7daf1b94cd2f5fe3ff9 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:50:00 +0800 Subject: [PATCH 12/14] docs: update README (close #482) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 92eb7f6a..4f25d3f5 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 -如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 From bd6fe1e93cf96d79d3a6f7872a6df1c367872ed6 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:56:37 +0800 Subject: [PATCH 13/14] feat: able to config rate limit (close #477) --- README.md | 3 +++ common/constants.go | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4f25d3f5..b89c6be8 100644 --- a/README.md +++ b/README.md @@ -311,6 +311,9 @@ graph LR + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` +12. 请求频率限制: + + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 4a3f3f2b..69bd12a8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -114,10 +114,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = 180 + GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = 60 + GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 From d0a0e871e117591fdf018e619c5d71ee371331f6 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 3 Sep 2023 22:12:35 +0800 Subject: [PATCH 14/14] fix: support ali's embedding model (#481, close #469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat:支持阿里的 embedding 模型 * fix: add to model list --------- Co-authored-by: JustSong Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> --- common/model-ratio.go | 7 ++- controller/model.go | 9 +++ controller/relay-ali.go | 88 ++++++++++++++++++++++++++++ controller/relay-baidu.go | 15 +---- controller/relay-text.go | 24 +++++++- controller/relay.go | 19 ++++++ web/src/pages/Channel/EditChannel.js | 2 +- 7 files changed, 144 insertions(+), 20 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 70758805..eeb23e07 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag - "qwen-plus-v1": 0.5715, // Same as above - "SparkDesk": 0.8572, // TBD + "qwen-v1": 0.8572, // ¥0.012 / 1k tokens + "qwen-plus-v1": 1, // ¥0.014 / 1k tokens + "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 88f95f7b..637ebe10 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,15 @@ func init() { Root: "qwen-plus-v1", Parent: nil, }, + { + Id: "text-embedding-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "text-embedding-v1", + Parent: nil, + }, { Id: "SparkDesk", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 9dca9a89..50dc743c 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -35,6 +35,29 @@ type AliChatRequest struct { Parameters AliParameters `json:"parameters,omitempty"` } +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + type AliError struct { Code string `json:"code"` Message string `json:"message"` @@ -44,6 +67,7 @@ type AliError struct { type AliUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` } type AliOutput struct { @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { } } +func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { + return &AliEmbeddingRequest{ + Model: "text-embedding-v1", + Input: struct { + Texts []string `json:"texts"` + }{ + Texts: request.ParseInput(), + }, + } +} + +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var aliResponse AliEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&aliResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Code != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { + openAIEmbeddingResponse := OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + } + + for _, item := range response.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { choice := OpenAITextResponseChoice{ Index: 0, diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 39f31a9a..ed08ac04 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom } func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - baiduEmbeddingRequest := BaiduEmbeddingRequest{ - Input: nil, + return &BaiduEmbeddingRequest{ + Input: request.ParseInput(), } - switch request.Input.(type) { - case string: - baiduEmbeddingRequest.Input = []string{request.Input.(string)} - case []any: - for _, item := range request.Input.([]any) { - if str, ok := item.(string); ok { - baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) - } - } - } - return &baiduEmbeddingRequest } func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { diff --git a/controller/relay-text.go b/controller/relay-text.go index c6659799..b190a999 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -174,6 +174,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } case APITypeAIProxyLibrary: fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } @@ -262,8 +265,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) case APITypeAli: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err := json.Marshal(aliRequest) + var jsonStr []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - err, usage := aliHandler(c, resp) + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + err, usage = aliHandler(c, resp) + } if err != nil { return err } diff --git a/controller/relay.go b/controller/relay.go index 056d42d3..d20663f6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct { Functions any `json:"functions,omitempty"` } +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 2ad3dc6a..78ff1952 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -67,7 +67,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; break; case 17: - localModels = ['qwen-v1', 'qwen-plus-v1']; + localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];