diff --git a/README.md b/README.md
index 8be12fa4..7fd38ec3 100644
--- a/README.md
+++ b/README.md
@@ -87,16 +87,19 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
12. 支持以美元为单位显示额度。
13. 支持发布公告,设置充值链接,设置新用户初始额度。
14. 支持模型映射,重定向用户的请求模型。
-15. 支持丰富的**自定义**设置,
+15. 支持失败自动重试。
+16. 支持绘图接口。
+17. 支持丰富的**自定义**设置,
1. 支持自定义系统名称,logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
-16. 支持通过系统访问令牌访问管理 API。
-17. 支持 Cloudflare Turnstile 用户校验。
-18. 支持用户管理,支持**多种用户登录注册方式**:
+18. 支持通过系统访问令牌访问管理 API。
+19. 支持 Cloudflare Turnstile 用户校验。
+20. 支持用户管理,支持**多种用户登录注册方式**:
+ 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
-19. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
+21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。
+22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
## 部署
### 基于 Docker 进行部署
diff --git a/common/constants.go b/common/constants.go
index 7d505141..5d1adbd7 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -72,6 +72,7 @@ var AutomaticDisableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
+var RetryTimes = 0
var RootUserEmail = ""
diff --git a/controller/model.go b/controller/model.go
index 2be935d6..5d7becb7 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -252,6 +252,24 @@ func init() {
Root: "code-davinci-edit-001",
Parent: nil,
},
+ {
+ Id: "ChatGLM",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "thudm",
+ Permission: permission,
+ Root: "ChatGLM",
+ Parent: nil,
+ },
+ {
+ Id: "ChatGLM2",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "thudm",
+ Permission: permission,
+ Root: "ChatGLM2",
+ Parent: nil,
+ },
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
diff --git a/controller/relay-image.go b/controller/relay-image.go
index e0483d56..7a37be80 100644
--- a/controller/relay-image.go
+++ b/controller/relay-image.go
@@ -22,26 +22,26 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
- var textRequest GeneralOpenAIRequest
+ var imageRequest ImageRequest
if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &textRequest)
+ err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
// Prompt validation
- if textRequest.Prompt == "" {
+ if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
- if textRequest.Size != "" && textRequest.Size != "256x256" && textRequest.Size != "512x512" && textRequest.Size != "1024x1024" {
+ if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
}
- // N should between 1 to 10
- if textRequest.N != 0 && (textRequest.N < 1 || textRequest.N > 10) {
+ // N should between 1 and 10
+ if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
@@ -71,7 +71,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var requestBody io.Reader
if isModelMapped {
- jsonStr, err := json.Marshal(textRequest)
+ jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
@@ -87,14 +87,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
sizeRatio := 1.0
// Size
- if textRequest.Size == "256x256" {
+ if imageRequest.Size == "256x256" {
sizeRatio = 1
- } else if textRequest.Size == "512x512" {
+ } else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
- } else if textRequest.Size == "1024x1024" {
+ } else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
- quota := int(ratio * sizeRatio * 1000)
+ quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 0f2472f7..d0c06100 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -33,6 +33,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
+ if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
@@ -478,7 +481,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
- log.Print(data)
+ // some implementations may add \r at the end of data
+ data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
diff --git a/controller/relay.go b/controller/relay.go
index 84d7f7bd..13179e6c 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "strconv"
"strings"
"github.com/gin-gonic/gin"
@@ -59,6 +60,12 @@ type TextRequest struct {
//Stream bool `json:"stream"`
}
+type ImageRequest struct {
+ Prompt string `json:"prompt"`
+ N int `json:"n"`
+ Size string `json:"size"`
+}
+
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -139,6 +146,8 @@ func Relay(c *gin.Context) {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
+ } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
@@ -154,16 +163,25 @@ func Relay(c *gin.Context) {
err = relayTextHelper(c, relayMode)
}
if err != nil {
- if err.StatusCode == http.StatusTooManyRequests {
- err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ retryTimesStr := c.Query("retry")
+ retryTimes, _ := strconv.Atoi(retryTimesStr)
+ if retryTimesStr == "" {
+ retryTimes = common.RetryTimes
+ }
+ if retryTimes > 0 {
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
+ } else {
+ if err.StatusCode == http.StatusTooManyRequests {
+ err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ }
+ c.JSON(err.StatusCode, gin.H{
+ "error": err.OpenAIError,
+ })
}
- c.JSON(err.StatusCode, gin.H{
- "error": err.OpenAIError,
- })
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
- if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
+ if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 8eb17e1c..2b6ccbf9 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -74,6 +74,11 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "text-moderation-stable"
}
}
+ if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = c.Param("model")
+ }
+ }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
diff --git a/model/option.go b/model/option.go
index 4bb1425d..dd2e563e 100644
--- a/model/option.go
+++ b/model/option.go
@@ -71,6 +71,7 @@ func InitOptionMap() {
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
+ common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
+ case "RetryTimes":
+ common.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
diff --git a/router/relay-router.go b/router/relay-router.go
index 9c480c41..0c8e9415 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -25,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
+ relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.GET("/files", controller.RelayNotImplemented)
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 69100c85..2adc7fa4 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -20,6 +20,7 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
+ RetryTimes: 0,
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -122,6 +123,9 @@ const OperationSetting = () => {
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
+ if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
+ await updateOption('RetryTimes', inputs.RetryTimes);
+ }
break;
}
};
@@ -133,7 +137,7 @@ const OperationSetting = () => {