diff --git a/README.md b/README.md index 8be12fa4..7fd38ec3 100644 --- a/README.md +++ b/README.md @@ -87,16 +87,19 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 12. 支持以美元为单位显示额度。 13. 支持发布公告,设置充值链接,设置新用户初始额度。 14. 支持模型映射,重定向用户的请求模型。 -15. 支持丰富的**自定义**设置, +15. 支持失败自动重试。 +16. 支持绘图接口。 +17. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -16. 支持通过系统访问令牌访问管理 API。 -17. 支持 Cloudflare Turnstile 用户校验。 -18. 支持用户管理,支持**多种用户登录注册方式**: +18. 支持通过系统访问令牌访问管理 API。 +19. 支持 Cloudflare Turnstile 用户校验。 +20. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册以及通过邮箱进行密码重置。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 -19. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 +21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。 +22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 ## 部署 ### 基于 Docker 进行部署 diff --git a/common/constants.go b/common/constants.go index 7d505141..5d1adbd7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -72,6 +72,7 @@ var AutomaticDisableChannelEnabled = false var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false +var RetryTimes = 0 var RootUserEmail = "" diff --git a/controller/model.go b/controller/model.go index 2be935d6..5d7becb7 100644 --- a/controller/model.go +++ b/controller/model.go @@ -252,6 +252,24 @@ func init() { Root: "code-davinci-edit-001", Parent: nil, }, + { + Id: "ChatGLM", + Object: "model", + Created: 1677649963, + OwnedBy: "thudm", + Permission: permission, + Root: "ChatGLM", + Parent: nil, + }, + { + Id: "ChatGLM2", + Object: "model", + Created: 1677649963, + OwnedBy: "thudm", + Permission: permission, + Root: "ChatGLM2", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay-image.go b/controller/relay-image.go index e0483d56..7a37be80 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -22,26 +22,26 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") - var textRequest GeneralOpenAIRequest + var imageRequest ImageRequest if consumeQuota { - err := common.UnmarshalBodyReusable(c, &textRequest) + err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } // Prompt validation - if textRequest.Prompt == "" { + if imageRequest.Prompt == "" { return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) } // Not "256x256", "512x512", or "1024x1024" - if textRequest.Size != "" && textRequest.Size != "256x256" && textRequest.Size != "512x512" && textRequest.Size != "1024x1024" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) } - // N should between 1 to 10 - if textRequest.N != 0 && (textRequest.N < 1 || textRequest.N > 10) { + // N should between 1 and 10 + if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) } @@ -71,7 +71,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var requestBody io.Reader if isModelMapped { - jsonStr, err := json.Marshal(textRequest) + jsonStr, err := json.Marshal(imageRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -87,14 +87,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode sizeRatio := 1.0 // Size - if textRequest.Size == "256x256" { + if imageRequest.Size == "256x256" { sizeRatio = 1 - } else if textRequest.Size == "512x512" { + } else if imageRequest.Size == "512x512" { sizeRatio = 1.125 - } else if textRequest.Size == "1024x1024" { + } else if imageRequest.Size == "1024x1024" { sizeRatio = 1.25 } - quota := int(ratio * sizeRatio * 1000) + quota := int(ratio*sizeRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) diff --git a/controller/relay-text.go b/controller/relay-text.go index 0f2472f7..d0c06100 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -33,6 +33,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } + if relayMode == RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } // request validation if textRequest.Model == "" { return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) @@ -478,7 +481,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if strings.HasPrefix(data, "data: [DONE]") { data = data[:12] } - log.Print(data) + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") c.Render(-1, common.CustomEvent{Data: data}) return true case <-stopChan: diff --git a/controller/relay.go b/controller/relay.go index 84d7f7bd..13179e6c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "strconv" "strings" "github.com/gin-gonic/gin" @@ -59,6 +60,12 @@ type TextRequest struct { //Stream bool `json:"stream"` } +type ImageRequest struct { + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -139,6 +146,8 @@ func Relay(c *gin.Context) { relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + relayMode = RelayModeEmbeddings } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { relayMode = RelayModeModerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { @@ -154,16 +163,25 @@ func Relay(c *gin.Context) { err = relayTextHelper(c, relayMode) } if err != nil { - if err.StatusCode == http.StatusTooManyRequests { - err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + retryTimesStr := c.Query("retry") + retryTimes, _ := strconv.Atoi(retryTimesStr) + if retryTimesStr == "" { + retryTimes = common.RetryTimes + } + if retryTimes > 0 { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) + } else { + if err.StatusCode == http.StatusTooManyRequests { + err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + } + c.JSON(err.StatusCode, gin.H{ + "error": err.OpenAIError, + }) } - c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, - }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") { + if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) diff --git a/middleware/distributor.go b/middleware/distributor.go index 8eb17e1c..2b6ccbf9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -74,6 +74,11 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "text-moderation-stable" } } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if modelRequest.Model == "" { modelRequest.Model = "dall-e" diff --git a/model/option.go b/model/option.go index 4bb1425d..dd2e563e 100644 --- a/model/option.go +++ b/model/option.go @@ -71,6 +71,7 @@ func InitOptionMap() { common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaRemindThreshold, _ = strconv.Atoi(value) case "PreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": diff --git a/router/relay-router.go b/router/relay-router.go index 9c480c41..0c8e9415 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -25,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/images/edits", controller.RelayNotImplemented) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) + relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented) diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 69100c85..2adc7fa4 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -20,6 +20,7 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', + RetryTimes: 0, }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -122,6 +123,9 @@ const OperationSetting = () => { if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } break; } }; @@ -133,7 +137,7 @@ const OperationSetting = () => {
通用设置
- + { step='0.01' placeholder='一单位货币能兑换的额度' /> + { const [groupOptions, setGroupOptions] = useState([]); const [basicModels, setBasicModels] = useState([]); const [fullModels, setFullModels] = useState([]); + const [customModel, setCustomModel] = useState(''); const handleInputChange = (e, { name, value }) => { console.log(name, value) setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -45,6 +46,19 @@ const EditChannel = () => { data.models = []; } else { data.models = data.models.split(','); + setTimeout(() => { + let localModelOptions = [...modelOptions]; + data.models.forEach((model) => { + if (!localModelOptions.find((option) => option.key === model)) { + localModelOptions.push({ + key: model, + text: model, + value: model + }); + } + }); + setModelOptions(localModelOptions); + }, 1000); } if (data.group === '') { data.groups = []; @@ -265,6 +279,27 @@ const EditChannel = () => { + { + let localModels = [...inputs.models]; + localModels.push(customModel); + let localModelOptions = [...modelOptions]; + localModelOptions.push({ + key: customModel, + text: customModel, + value: customModel, + }); + setModelOptions(localModelOptions); + handleInputChange(null, { name: 'models', value: localModels }); + }}>填入 + } + placeholder='输入自定义模型名称' + value={customModel} + onChange={(e, { value }) => { + setCustomModel(value); + }} + /> { /> ) } - +