diff --git a/README.md b/README.md index bcf8b664..18835948 100644 --- a/README.md +++ b/README.md @@ -70,17 +70,18 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 9. 支持渠道**设置模型列表**。 10. 支持**查看额度明细**。 -11. 支持发布公告,设置充值链接,设置新用户初始额度。 -12. 支持丰富的**自定义**设置, +11. 支持**用户邀请奖励**。 +12. 支持发布公告,设置充值链接,设置新用户初始额度。 +13. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -13. 支持通过系统访问令牌访问管理 API。 -14. 支持 Cloudflare Turnstile 用户校验。 -15. 支持用户管理,支持**多种用户登录注册方式**: +14. 支持通过系统访问令牌访问管理 API。 +15. 支持 Cloudflare Turnstile 用户校验。 +16. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册以及通过邮箱进行密码重置。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 -16. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 +17. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 ## 部署 ### 基于 Docker 进行部署 @@ -157,9 +158,9 @@ sudo service nginx restart ### 宝塔部署教程 -详见[#175](https://github.com/songquanpeng/one-api/issues/175)。 +详见 [#175](https://github.com/songquanpeng/one-api/issues/175)。 -如果部署后访问出现空白页面,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。 +如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 ### 部署第三方服务配合 One API 使用 > 欢迎 PR 添加更多示例。 @@ -277,7 +278,7 @@ https://openai.justsong.cn + 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 ## 注意 -本项目为开源项目,请在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及法律法规的情况下使用,不得用于非法用途。 +本项目为开源项目,请在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 本项目使用 MIT 协议进行开源,请以某种方式保留 One API 的版权信息。 diff --git a/controller/relay-image.go b/controller/relay-image.go new file mode 100644 index 00000000..c5311272 --- /dev/null +++ b/controller/relay-image.go @@ -0,0 +1,34 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "io" + "net/http" +) + +func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + // TODO: this part is not finished + req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusOK) + } + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusOK) + } + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusOK) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusOK) + } + return nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go new file mode 100644 index 00000000..2a07f2fa --- /dev/null +++ b/controller/relay-text.go @@ -0,0 +1,266 @@ +package controller + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/model" + "strings" +) + +func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + channelType := c.GetInt("channel") + tokenId := c.GetInt("token_id") + consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + var textRequest GeneralOpenAIRequest + if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + } + if relayMode == RelayModeModeration && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + if channelType == common.ChannelTypeCustom { + baseURL = c.GetString("base_url") + } else if channelType == common.ChannelTypeOpenAI { + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + } + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + if channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + requestURL := strings.Split(requestURL, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) + baseURL = c.GetString("base_url") + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := textRequest.Model + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + } else if channelType == common.ChannelTypePaLM { + err := relayPaLM(textRequest, c) + return err + } + var promptTokens int + switch relayMode { + case RelayModeChatCompletions: + promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + case RelayModeCompletions: + promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) + case RelayModeModeration: + promptTokens = countTokenInput(textRequest.Input, textRequest.Model) + } + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + textRequest.MaxTokens + } + modelRatio := common.GetModelRatio(textRequest.Model) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + if consumeQuota { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK) + } + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusOK) + } + if channelType == common.ChannelTypeAzure { + key := c.Request.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + req.Header.Set("api-key", key) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("Connection", c.Request.Header.Get("Connection")) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusOK) + } + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusOK) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusOK) + } + var textResponse TextResponse + isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + var streamResponseText string + + defer func() { + if consumeQuota { + quota := 0 + completionRatio := 1.34 // default for gpt-3 + if strings.HasPrefix(textRequest.Model, "gpt-4") { + completionRatio = 2 + } + if isStream { + responseTokens := countTokenText(streamResponseText, textRequest.Model) + quota = promptTokens + int(float64(responseTokens)*completionRatio) + } else { + quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio) + } + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("Error consuming token remain quota: " + err.Error()) + } + tokenName := c.GetString("token_name") + userId := c.GetInt("id") + model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %d 点额度(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, quota, modelRatio, groupRatio)) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }() + + if isStream { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + common.SysError("Invalid stream response: " + data) + continue + } + dataChan <- data + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + switch relayMode { + case RelayModeChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("Error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + case RelayModeCompletions: + var streamResponse CompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("Error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Text + } + } + } + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + if strings.HasPrefix(data, "data: [DONE]") { + data = data[:12] + } + c.Render(-1, common.CustomEvent{Data: data}) + return true + case <-stopChan: + return false + } + }) + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusOK) + } + return nil + } else { + if consumeQuota { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusOK) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusOK) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK) + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + } + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the client will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusOK) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusOK) + } + return nil + } +} diff --git a/controller/relay-utils.go b/controller/relay-utils.go index a2dc2685..0c0df970 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -77,3 +77,15 @@ func countTokenText(text string, model string) int { token := tokenEncoder.Encode(text, nil, nil) return len(token) } + +func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + Code: code, + } + return &OpenAIErrorWithStatusCode{ + OpenAIError: openAIError, + StatusCode: statusCode, + } +} diff --git a/controller/relay.go b/controller/relay.go index 6e248d67..a9a5a364 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,15 +1,10 @@ package controller import ( - "bufio" - "bytes" - "encoding/json" "fmt" "github.com/gin-gonic/gin" - "io" "net/http" "one-api/common" - "one-api/model" "strings" ) @@ -25,6 +20,7 @@ const ( RelayModeCompletions RelayModeEmbeddings RelayModeModeration + RelayModeImagesGenerations ) // https://platform.openai.com/docs/api-reference/chat @@ -104,8 +100,16 @@ func Relay(c *gin.Context) { relayMode = RelayModeEmbeddings } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { relayMode = RelayModeModeration + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + relayMode = RelayModeImagesGenerations + } + var err *OpenAIErrorWithStatusCode + switch relayMode { + case RelayModeImagesGenerations: + err = relayImageHelper(c, relayMode) + default: + err = relayHelper(c, relayMode) } - err := relayHelper(c, relayMode) if err != nil { if err.StatusCode == http.StatusTooManyRequests { err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" @@ -124,276 +128,6 @@ func Relay(c *gin.Context) { } } -func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { - openAIError := OpenAIError{ - Message: err.Error(), - Type: "one_api_error", - Code: code, - } - return &OpenAIErrorWithStatusCode{ - OpenAIError: openAIError, - StatusCode: statusCode, - } -} - -func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - channelType := c.GetInt("channel") - tokenId := c.GetInt("token_id") - consumeQuota := c.GetBool("consume_quota") - group := c.GetString("group") - var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - } - if relayMode == RelayModeModeration && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if channelType == common.ChannelTypeCustom { - baseURL = c.GetString("base_url") - } else if channelType == common.ChannelTypeOpenAI { - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) - } else if channelType == common.ChannelTypePaLM { - err := relayPaLM(textRequest, c) - return err - } else { - // 强制使用 0613 模型 - textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0301") - textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0314") - textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0613") - textRequest.Model = textRequest.Model + "-0613" - } - var promptTokens int - switch relayMode { - case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModeration: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) - } - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens - } - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) - if consumeQuota { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK) - } - } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusOK) - } - if channelType == common.ChannelTypeAzure { - key := c.Request.Header.Get("Authorization") - key = strings.TrimPrefix(key, "Bearer ") - req.Header.Set("api-key", key) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - req.Header.Set("Connection", c.Request.Header.Get("Connection")) - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusOK) - } - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusOK) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusOK) - } - var textResponse TextResponse - isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - var streamResponseText string - - defer func() { - if consumeQuota { - quota := 0 - completionRatio := 1.34 // default for gpt-3 - if strings.HasPrefix(textRequest.Model, "gpt-4") { - completionRatio = 2 - } - if isStream { - responseTokens := countTokenText(streamResponseText, textRequest.Model) - quota = promptTokens + int(float64(responseTokens)*completionRatio) - } else { - quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio) - } - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("Error consuming token remain quota: " + err.Error()) - } - tokenName := c.GetString("token_name") - userId := c.GetInt("id") - model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %d 点额度(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, quota, modelRatio, groupRatio)) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - - if isStream { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - - if i := strings.Index(string(data), "\n\n"); i >= 0 { - return i + 2, data[0:i], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // must be something wrong! - common.SysError("Invalid stream response: " + data) - continue - } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("Error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("Error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Text - } - } - } - } - stopChan <- true - }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false - } - }) - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusOK) - } - return nil - } else { - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusOK) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusOK) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK) - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - } - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the client will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusOK) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusOK) - } - return nil - } -} - func RelayNotImplemented(c *gin.Context) { err := OpenAIError{ Message: "API not implemented", diff --git a/model/token.go b/model/token.go index 8ce252b2..64e52dcd 100644 --- a/model/token.go +++ b/model/token.go @@ -28,7 +28,7 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { } func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) { - err = DB.Where("user_id = ?", userId).Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&tokens).Error + err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error return tokens, err } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js new file mode 100644 index 00000000..3f3a4ab0 --- /dev/null +++ b/web/src/components/OperationSetting.js @@ -0,0 +1,282 @@ +import React, { useEffect, useState } from 'react'; +import { Divider, Form, Grid, Header } from 'semantic-ui-react'; +import { API, showError, verifyJSON } from '../helpers'; + +const OperationSetting = () => { + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + AutomaticDisableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '' + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + if (key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['AutomaticDisableChannelEnabled'] !== inputs.AutomaticDisableChannelEnabled) { + await updateOption('AutomaticDisableChannelEnabled', inputs.AutomaticDisableChannelEnabled); + } + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + break; + } + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 监控设置 +
+ + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ); +}; + +export default OperationSetting; diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 786e935b..658e5294 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -26,18 +26,6 @@ const SystemSetting = () => { TurnstileSiteKey: '', TurnstileSecretKey: '', RegisterEnabled: '', - QuotaForNewUser: 0, - QuotaForInviter: 0, - QuotaForInvitee: 0, - QuotaRemindThreshold: 0, - PreConsumedQuota: 0, - ModelRatio: '', - GroupRatio: '', - TopUpLink: '', - ChatLink: '', - AutomaticDisableChannelEnabled: '', - ChannelDisableThreshold: 0, - LogConsumeEnabled: '' }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -71,8 +59,6 @@ const SystemSetting = () => { case 'WeChatAuthEnabled': case 'TurnstileCheckEnabled': case 'RegisterEnabled': - case 'AutomaticDisableChannelEnabled': - case 'LogConsumeEnabled': value = inputs[key] === 'true' ? 'false' : 'true'; break; default: @@ -102,16 +88,7 @@ const SystemSetting = () => { name === 'WeChatServerToken' || name === 'WeChatAccountQRCodeImageURL' || name === 'TurnstileSiteKey' || - name === 'TurnstileSecretKey' || - name === 'QuotaForNewUser' || - name === 'QuotaForInviter' || - name === 'QuotaForInvitee' || - name === 'QuotaRemindThreshold' || - name === 'PreConsumedQuota' || - name === 'ModelRatio' || - name === 'GroupRatio' || - name === 'TopUpLink' || - name === 'ChatLink' + name === 'TurnstileSecretKey' ) { setInputs((inputs) => ({ ...inputs, [name]: value })); } else { @@ -124,44 +101,6 @@ const SystemSetting = () => { await updateOption('ServerAddress', ServerAddress); }; - const submitOperationConfig = async () => { - if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { - await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); - } - if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { - await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); - } - if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { - await updateOption('QuotaForInviter', inputs.QuotaForInviter); - } - if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { - await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); - } - if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { - await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); - } - if (originInputs['ModelRatio'] !== inputs.ModelRatio) { - if (!verifyJSON(inputs.ModelRatio)) { - showError('模型倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('ModelRatio', inputs.ModelRatio); - } - if (originInputs['GroupRatio'] !== inputs.GroupRatio) { - if (!verifyJSON(inputs.GroupRatio)) { - showError('分组倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('GroupRatio', inputs.GroupRatio); - } - if (originInputs['TopUpLink'] !== inputs.TopUpLink) { - await updateOption('TopUpLink', inputs.TopUpLink); - } - if (originInputs['ChatLink'] !== inputs.ChatLink) { - await updateOption('ChatLink', inputs.ChatLink); - } - }; - const submitSMTP = async () => { if (originInputs['SMTPServer'] !== inputs.SMTPServer) { await updateOption('SMTPServer', inputs.SMTPServer); @@ -300,135 +239,6 @@ const SystemSetting = () => { /> -
- 运营设置 -
- - - - - - - - - - - - - - - - - - - 保存运营设置 - -
- 监控设置 -
- - - - - - -
配置 SMTP 用以支持系统的邮件发送 diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index b63a2b00..d5d735d7 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -154,7 +154,7 @@ const TokensTable = () => { icon='search' fluid iconPosition='left' - placeholder='搜索令牌的 ID 和名称 ...' + placeholder='搜索令牌的名称 ...' value={searchKeyword} loading={searching} onChange={handleKeywordChange} diff --git a/web/src/pages/Setting/index.js b/web/src/pages/Setting/index.js index 392e2ca7..30d0ef28 100644 --- a/web/src/pages/Setting/index.js +++ b/web/src/pages/Setting/index.js @@ -4,6 +4,7 @@ import SystemSetting from '../../components/SystemSetting'; import { isRoot } from '../../helpers'; import OtherSetting from '../../components/OtherSetting'; import PersonalSetting from '../../components/PersonalSetting'; +import OperationSetting from '../../components/OperationSetting'; const Setting = () => { let panes = [ @@ -18,6 +19,14 @@ const Setting = () => { ]; if (isRoot()) { + panes.push({ + menuItem: '运营设置', + render: () => ( + + + + ) + }); panes.push({ menuItem: '系统设置', render: () => (