diff --git a/README.md b/README.md
index bcf8b664..18835948 100644
--- a/README.md
+++ b/README.md
@@ -70,17 +70,18 @@ _✨ An all-in-one OpenAI interface, integrating various API access methods, ready to use out of the box
8. Supports **user groups** and **channel groups**, with different rate multipliers configurable per group.
9. Supports **configuring the model list** for each channel.
10. Supports **viewing quota details**.
-11. Supports publishing announcements, setting a top-up link, and setting the initial quota for new users.
-12. Supports rich **customization** options:
+11. Supports **referral rewards** for inviting new users.
+12. Supports publishing announcements, setting a top-up link, and setting the initial quota for new users.
+13. Supports rich **customization** options:
    1. Customizable system name, logo, and footer.
    2. Customizable home page and about page, using HTML & Markdown, or by embedding a standalone web page via iframe.
-13. Supports accessing the management API via a system access token.
-14. Supports Cloudflare Turnstile user verification.
-15. Supports user management with **multiple sign-in and sign-up methods**:
+14. Supports accessing the management API via a system access token.
+15. Supports Cloudflare Turnstile user verification.
+16. Supports user management with **multiple sign-in and sign-up methods**:
    + Email sign-up and sign-in, with password reset via email.
    + [GitHub OAuth](https://github.com/settings/applications/new).
    + WeChat Official Account authorization (requires additionally deploying [WeChat Server](https://github.com/songquanpeng/wechat-server)).
-16. When other large-model vendors open up their APIs, they will be supported promptly and wrapped in the same unified API format.
+17. When other large-model vendors open up their APIs, they will be supported promptly and wrapped in the same unified API format.
## Deployment
### Deploying with Docker
@@ -157,9 +158,9 @@ sudo service nginx restart
### Baota Panel Deployment Tutorial
-See[#175](https://github.com/songquanpeng/one-api/issues/175).
+See [#175](https://github.com/songquanpeng/one-api/issues/175).
-If a blank page appears after deployment, see[#97](https://github.com/songquanpeng/one-api/issues/97).
+If a blank page appears after deployment, see [#97](https://github.com/songquanpeng/one-api/issues/97).
### Deploying Third-Party Services Alongside One API
> PRs adding more examples are welcome.
@@ -277,7 +278,7 @@ https://openai.justsong.cn
+ Most likely the IP of your deployment site or your proxy node has been banned by Cloudflare.
## Notes
-This is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) as well as applicable laws and regulations, and do not use it for illegal purposes.
+This is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) as well as **applicable laws and regulations**, and do not use it for illegal purposes.
This project is open-sourced under the MIT License. Please retain the One API copyright notice in some form.
diff --git a/controller/relay-image.go b/controller/relay-image.go
new file mode 100644
index 00000000..c5311272
--- /dev/null
+++ b/controller/relay-image.go
@@ -0,0 +1,34 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+)
+
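+// relayImageHelper relays an image generation request by forwarding it
+// upstream verbatim and copying the upstream response back to the client.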
+func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+	// TODO: this part is not finished; c.Request.RequestURI is only a path,
+	// so the upstream base URL still has to be resolved before forwarding.
+	req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body)
+	if err != nil {
+		return errorWrapper(err, "new_request_failed", http.StatusOK)
+	}
+	client := &http.Client{}
+	resp, err := client.Do(req)
+ if err != nil {
+ return errorWrapper(err, "do_request_failed", http.StatusOK)
+ }
+ err = req.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_request_body_failed", http.StatusOK)
+ }
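+	// Mirror the upstream status code and headers, then stream the body through.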
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+ }
+ return nil
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
new file mode 100644
index 00000000..2a07f2fa
--- /dev/null
+++ b/controller/relay-text.go
@@ -0,0 +1,266 @@
+package controller
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strings"
+)
+
+func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+ channelType := c.GetInt("channel")
+ tokenId := c.GetInt("token_id")
+ consumeQuota := c.GetBool("consume_quota")
+ group := c.GetString("group")
+ var textRequest GeneralOpenAIRequest
+ if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
+ err := common.UnmarshalBodyReusable(c, &textRequest)
+ if err != nil {
+ return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+ }
+ }
+ if relayMode == RelayModeModeration && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ baseURL := common.ChannelBaseURLs[channelType]
+ requestURL := c.Request.URL.String()
+ if channelType == common.ChannelTypeCustom {
+ baseURL = c.GetString("base_url")
+ } else if channelType == common.ChannelTypeOpenAI {
+ if c.GetString("base_url") != "" {
+ baseURL = c.GetString("base_url")
+ }
+ }
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+ if channelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion == "" {
+ apiVersion = c.GetString("api_version")
+ }
+ requestURL := strings.Split(requestURL, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
+ baseURL = c.GetString("base_url")
+ task := strings.TrimPrefix(requestURL, "/v1/")
+ model_ := textRequest.Model
+ model_ = strings.Replace(model_, ".", "", -1)
+ // https://github.com/songquanpeng/one-api/issues/67
+ model_ = strings.TrimSuffix(model_, "-0301")
+ model_ = strings.TrimSuffix(model_, "-0314")
+ model_ = strings.TrimSuffix(model_, "-0613")
+ fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+	} else if channelType == common.ChannelTypePaLM {
+		return relayPaLM(textRequest, c)
+	}
+ var promptTokens int
+ switch relayMode {
+ case RelayModeChatCompletions:
+ promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+ case RelayModeCompletions:
+ promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
+ case RelayModeModeration:
+ promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
+ }
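+	// Pre-consume an estimated cost up front; when the client sets max_tokens,
+	// prompt tokens plus max_tokens bounds the total.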
+ preConsumedTokens := common.PreConsumedQuota
+ if textRequest.MaxTokens != 0 {
+ preConsumedTokens = promptTokens + textRequest.MaxTokens
+ }
+ modelRatio := common.GetModelRatio(textRequest.Model)
+ groupRatio := common.GetGroupRatio(group)
+ ratio := modelRatio * groupRatio
+ preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+ if consumeQuota {
+ err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+ if err != nil {
+ return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
+ }
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+ if err != nil {
+ return errorWrapper(err, "new_request_failed", http.StatusOK)
+ }
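+	// Azure authenticates with an "api-key" header instead of a Bearer token.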
+ if channelType == common.ChannelTypeAzure {
+ key := c.Request.Header.Get("Authorization")
+ key = strings.TrimPrefix(key, "Bearer ")
+ req.Header.Set("api-key", key)
+ } else {
+ req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+ }
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return errorWrapper(err, "do_request_failed", http.StatusOK)
+ }
+ err = req.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_request_body_failed", http.StatusOK)
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_request_body_failed", http.StatusOK)
+ }
+ var textResponse TextResponse
+ isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+ var streamResponseText string
+
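+	// Once the response has been handled, settle the actual cost: weight the
+	// completion tokens by the model's completion ratio, apply the combined
+	// ratio, and reconcile the difference against the pre-consumed quota.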
+ defer func() {
+ if consumeQuota {
+ quota := 0
+ completionRatio := 1.34 // default for gpt-3
+ if strings.HasPrefix(textRequest.Model, "gpt-4") {
+ completionRatio = 2
+ }
+ if isStream {
+ responseTokens := countTokenText(streamResponseText, textRequest.Model)
+ quota = promptTokens + int(float64(responseTokens)*completionRatio)
+ } else {
+ quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio)
+ }
+ quota = int(float64(quota) * ratio)
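+			// Charge at least 1 point whenever the ratio is non-zero.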
+ if ratio != 0 && quota <= 0 {
+ quota = 1
+ }
+ quotaDelta := quota - preConsumedQuota
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+ if err != nil {
+ common.SysError("Error consuming token remain quota: " + err.Error())
+ }
+ tokenName := c.GetString("token_name")
+ userId := c.GetInt("id")
+			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("Token \"%s\" used model %s and consumed %d quota points (model ratio %.2f, group ratio %.2f)", tokenName, textRequest.Model, quota, modelRatio, groupRatio))
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+ }()
+
+ if isStream {
+ scanner := bufio.NewScanner(resp.Body)
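+		// Split the upstream stream into individual server-sent events,
+		// which are delimited by blank lines ("\n\n").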
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
+ }
+
+ if i := strings.Index(string(data), "\n\n"); i >= 0 {
+ return i + 2, data[0:i], nil
+ }
+
+ if atEOF {
+ return len(data), data, nil
+ }
+
+ return 0, nil, nil
+ })
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
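+		// A reader goroutine parses upstream events and accumulates the
+		// completion text for billing, while c.Stream below forwards each
+		// event to the client.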
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+				if len(data) < 6 { // shorter than the "data: " prefix, so it must be malformed
+ common.SysError("Invalid stream response: " + data)
+ continue
+ }
+ dataChan <- data
+ data = data[6:]
+ if !strings.HasPrefix(data, "[DONE]") {
+ switch relayMode {
+ case RelayModeChatCompletions:
+ var streamResponse ChatCompletionsStreamResponse
+ err = json.Unmarshal([]byte(data), &streamResponse)
+ if err != nil {
+ common.SysError("Error unmarshalling stream response: " + err.Error())
+ return
+ }
+ for _, choice := range streamResponse.Choices {
+ streamResponseText += choice.Delta.Content
+ }
+ case RelayModeCompletions:
+ var streamResponse CompletionsStreamResponse
+ err = json.Unmarshal([]byte(data), &streamResponse)
+ if err != nil {
+ common.SysError("Error unmarshalling stream response: " + err.Error())
+ return
+ }
+ for _, choice := range streamResponse.Choices {
+ streamResponseText += choice.Text
+ }
+ }
+ }
+ }
+ stopChan <- true
+ }()
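+		// Standard SSE response headers; "X-Accel-Buffering: no" keeps reverse
+		// proxies such as nginx from buffering the stream.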
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
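+				// Drop anything trailing the final "data: [DONE]" marker.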
+ if strings.HasPrefix(data, "data: [DONE]") {
+ data = data[:12]
+ }
+ c.Render(-1, common.CustomEvent{Data: data})
+ return true
+ case <-stopChan:
+ return false
+ }
+ })
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+ }
+ return nil
+ } else {
+ if consumeQuota {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return errorWrapper(err, "read_response_body_failed", http.StatusOK)
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+ }
+ err = json.Unmarshal(responseBody, &textResponse)
+ if err != nil {
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
+ }
+ if textResponse.Error.Type != "" {
+ return &OpenAIErrorWithStatusCode{
+ OpenAIError: textResponse.Error,
+ StatusCode: resp.StatusCode,
+ }
+ }
+ // Reset response body
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ }
+ // We shouldn't set the header before we parse the response body, because the parse part may fail.
+ // And then we will have to send an error response, but in this case, the header has already been set.
+ // So the client will be confused by the response.
+ // For example, Postman will report error, and we cannot check the response at all.
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+ }
+ return nil
+ }
+}
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index a2dc2685..0c0df970 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -77,3 +77,15 @@ func countTokenText(text string, model string) int {
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}
+
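+// errorWrapper packs a plain error into the OpenAI-style error body that the
+// relay returns to clients, for example:
+//
+//	return errorWrapper(err, "do_request_failed", http.StatusOK)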
+func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
+ openAIError := OpenAIError{
+ Message: err.Error(),
+ Type: "one_api_error",
+ Code: code,
+ }
+ return &OpenAIErrorWithStatusCode{
+ OpenAIError: openAIError,
+ StatusCode: statusCode,
+ }
+}
diff --git a/controller/relay.go b/controller/relay.go
index 6e248d67..a9a5a364 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,15 +1,10 @@
package controller
import (
- "bufio"
- "bytes"
- "encoding/json"
"fmt"
"github.com/gin-gonic/gin"
- "io"
"net/http"
"one-api/common"
- "one-api/model"
"strings"
)
@@ -25,6 +20,7 @@ const (
RelayModeCompletions
RelayModeEmbeddings
RelayModeModeration
+ RelayModeImagesGenerations
)
// https://platform.openai.com/docs/api-reference/chat
@@ -104,8 +100,16 @@ func Relay(c *gin.Context) {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModeration
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+ relayMode = RelayModeImagesGenerations
+ }
+ var err *OpenAIErrorWithStatusCode
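+	// Image generation requests go through their own passthrough helper; all
+	// text-style endpoints share relayHelper.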
+ switch relayMode {
+ case RelayModeImagesGenerations:
+ err = relayImageHelper(c, relayMode)
+ default:
+ err = relayHelper(c, relayMode)
}
- err := relayHelper(c, relayMode)
if err != nil {
if err.StatusCode == http.StatusTooManyRequests {
			err.OpenAIError.Message = "The current group is overloaded, please try again later, or upgrade your account for better service quality."
@@ -124,276 +128,6 @@ func Relay(c *gin.Context) {
}
}
-func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
- openAIError := OpenAIError{
- Message: err.Error(),
- Type: "one_api_error",
- Code: code,
- }
- return &OpenAIErrorWithStatusCode{
- OpenAIError: openAIError,
- StatusCode: statusCode,
- }
-}
-
-func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
- channelType := c.GetInt("channel")
- tokenId := c.GetInt("token_id")
- consumeQuota := c.GetBool("consume_quota")
- group := c.GetString("group")
- var textRequest GeneralOpenAIRequest
- if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
- err := common.UnmarshalBodyReusable(c, &textRequest)
- if err != nil {
- return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
- }
- if relayMode == RelayModeModeration && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- baseURL := common.ChannelBaseURLs[channelType]
- requestURL := c.Request.URL.String()
- if channelType == common.ChannelTypeCustom {
- baseURL = c.GetString("base_url")
- } else if channelType == common.ChannelTypeOpenAI {
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
- }
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- if channelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
- }
- requestURL := strings.Split(requestURL, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
- baseURL = c.GetString("base_url")
- task := strings.TrimPrefix(requestURL, "/v1/")
- model_ := textRequest.Model
- model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
- fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
- } else if channelType == common.ChannelTypePaLM {
- err := relayPaLM(textRequest, c)
- return err
- } else {
-		// Force the 0613 model variant
- textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0301")
- textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0314")
- textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0613")
- textRequest.Model = textRequest.Model + "-0613"
- }
- var promptTokens int
- switch relayMode {
- case RelayModeChatCompletions:
- promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
- case RelayModeCompletions:
- promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
- case RelayModeModeration:
- promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
- }
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + textRequest.MaxTokens
- }
- modelRatio := common.GetModelRatio(textRequest.Model)
- groupRatio := common.GetGroupRatio(group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
- if consumeQuota {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
- if err != nil {
- return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
- }
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
- if err != nil {
- return errorWrapper(err, "new_request_failed", http.StatusOK)
- }
- if channelType == common.ChannelTypeAzure {
- key := c.Request.Header.Get("Authorization")
- key = strings.TrimPrefix(key, "Bearer ")
- req.Header.Set("api-key", key)
- } else {
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
- }
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- req.Header.Set("Connection", c.Request.Header.Get("Connection"))
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- return errorWrapper(err, "do_request_failed", http.StatusOK)
- }
- err = req.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusOK)
- }
- err = c.Request.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusOK)
- }
- var textResponse TextResponse
- isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
- var streamResponseText string
-
- defer func() {
- if consumeQuota {
- quota := 0
- completionRatio := 1.34 // default for gpt-3
- if strings.HasPrefix(textRequest.Model, "gpt-4") {
- completionRatio = 2
- }
- if isStream {
- responseTokens := countTokenText(streamResponseText, textRequest.Model)
- quota = promptTokens + int(float64(responseTokens)*completionRatio)
- } else {
- quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio)
- }
- quota = int(float64(quota) * ratio)
- if ratio != 0 && quota <= 0 {
- quota = 1
- }
- quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.SysError("Error consuming token remain quota: " + err.Error())
- }
- tokenName := c.GetString("token_name")
- userId := c.GetInt("id")
-			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("Token \"%s\" used model %s and consumed %d quota points (model ratio %.2f, group ratio %.2f)", tokenName, textRequest.Model, quota, modelRatio, groupRatio))
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
- }
- }()
-
- if isStream {
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
-
- if i := strings.Index(string(data), "\n\n"); i >= 0 {
- return i + 2, data[0:i], nil
- }
-
- if atEOF {
- return len(data), data, nil
- }
-
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 6 { // must be something wrong!
- common.SysError("Invalid stream response: " + data)
- continue
- }
- dataChan <- data
- data = data[6:]
- if !strings.HasPrefix(data, "[DONE]") {
- switch relayMode {
- case RelayModeChatCompletions:
- var streamResponse ChatCompletionsStreamResponse
- err = json.Unmarshal([]byte(data), &streamResponse)
- if err != nil {
- common.SysError("Error unmarshalling stream response: " + err.Error())
- return
- }
- for _, choice := range streamResponse.Choices {
- streamResponseText += choice.Delta.Content
- }
- case RelayModeCompletions:
- var streamResponse CompletionsStreamResponse
- err = json.Unmarshal([]byte(data), &streamResponse)
- if err != nil {
- common.SysError("Error unmarshalling stream response: " + err.Error())
- return
- }
- for _, choice := range streamResponse.Choices {
- streamResponseText += choice.Text
- }
- }
- }
- }
- stopChan <- true
- }()
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("Transfer-Encoding", "chunked")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- if strings.HasPrefix(data, "data: [DONE]") {
- data = data[:12]
- }
- c.Render(-1, common.CustomEvent{Data: data})
- return true
- case <-stopChan:
- return false
- }
- })
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusOK)
- }
- return nil
- } else {
- if consumeQuota {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusOK)
- }
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusOK)
- }
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
- }
- if textResponse.Error.Type != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: textResponse.Error,
- StatusCode: resp.StatusCode,
- }
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- }
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the client will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
- }
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusOK)
- }
- return nil
- }
-}
-
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",
diff --git a/model/token.go b/model/token.go
index 8ce252b2..64e52dcd 100644
--- a/model/token.go
+++ b/model/token.go
@@ -28,7 +28,7 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
}
func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
- err = DB.Where("user_id = ?", userId).Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&tokens).Error
+ err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error
return tokens, err
}
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
new file mode 100644
index 00000000..3f3a4ab0
--- /dev/null
+++ b/web/src/components/OperationSetting.js
@@ -0,0 +1,282 @@
+import React, { useEffect, useState } from 'react';
+import { Divider, Form, Grid, Header } from 'semantic-ui-react';
+import { API, showError, verifyJSON } from '../helpers';
+
+const OperationSetting = () => {
+ let [inputs, setInputs] = useState({
+ QuotaForNewUser: 0,
+ QuotaForInviter: 0,
+ QuotaForInvitee: 0,
+ QuotaRemindThreshold: 0,
+ PreConsumedQuota: 0,
+ ModelRatio: '',
+ GroupRatio: '',
+ TopUpLink: '',
+ ChatLink: '',
+ AutomaticDisableChannelEnabled: '',
+ ChannelDisableThreshold: 0,
+ LogConsumeEnabled: ''
+ });
+ const [originInputs, setOriginInputs] = useState({});
+ let [loading, setLoading] = useState(false);
+
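+  // Fetch all options from the backend; the ratio maps arrive as JSON strings
+  // and are pretty-printed so they are easier to edit.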
+ const getOptions = async () => {
+ const res = await API.get('/api/option/');
+ const { success, message, data } = res.data;
+ if (success) {
+ let newInputs = {};
+ data.forEach((item) => {
+ if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+ item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+ }
+ newInputs[item.key] = item.value;
+ });
+ setInputs(newInputs);
+ setOriginInputs(newInputs);
+ } else {
+ showError(message);
+ }
+ };
+
+ useEffect(() => {
+ getOptions().then();
+ }, []);
+
+ const updateOption = async (key, value) => {
+ setLoading(true);
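+    // Boolean switches are stored as the strings 'true'/'false', so toggle
+    // based on the current value rather than the one passed in.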
+ if (key.endsWith('Enabled')) {
+ value = inputs[key] === 'true' ? 'false' : 'true';
+ }
+ const res = await API.put('/api/option/', {
+ key,
+ value
+ });
+ const { success, message } = res.data;
+ if (success) {
+ setInputs((inputs) => ({ ...inputs, [key]: value }));
+ } else {
+ showError(message);
+ }
+ setLoading(false);
+ };
+
+ const handleInputChange = async (e, { name, value }) => {
+ if (name.endsWith('Enabled')) {
+ await updateOption(name, value);
+ } else {
+ setInputs((inputs) => ({ ...inputs, [name]: value }));
+ }
+ };
+
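+  // Persist only the fields in the given settings group that actually changed.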
+ const submitConfig = async (group) => {
+ switch (group) {
+ case 'monitor':
+ if (originInputs['AutomaticDisableChannelEnabled'] !== inputs.AutomaticDisableChannelEnabled) {
+ await updateOption('AutomaticDisableChannelEnabled', inputs.AutomaticDisableChannelEnabled);
+ }
+ if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
+ await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
+ }
+ if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
+ await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
+ }
+ break;
+ case 'ratio':
+ if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
+ if (!verifyJSON(inputs.ModelRatio)) {
+          showError('Model ratio is not a valid JSON string');
+ return;
+ }
+ await updateOption('ModelRatio', inputs.ModelRatio);
+ }
+ if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
+ if (!verifyJSON(inputs.GroupRatio)) {
+          showError('Group ratio is not a valid JSON string');
+ return;
+ }
+ await updateOption('GroupRatio', inputs.GroupRatio);
+ }
+ break;
+ case 'quota':
+ if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
+ await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
+ }
+ if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
+ await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
+ }
+ if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
+ await updateOption('QuotaForInviter', inputs.QuotaForInviter);
+ }
+ if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
+ await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
+ }
+ break;
+ case 'general':
+ if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
+ await updateOption('TopUpLink', inputs.TopUpLink);
+ }
+ if (originInputs['ChatLink'] !== inputs.ChatLink) {
+ await updateOption('ChatLink', inputs.ChatLink);
+ }
+ break;
+ }
+ };
+
+ return (
+