From d85e356b6e6a7b22750f1aa93c3024862e82c58f Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Fri, 24 Nov 2023 20:42:29 +0800 Subject: [PATCH 01/30] refactor: remove consumeQuota related logic (#738) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: 删除relay-text中的consumeQuota变量 该变量始终为true,可以删除 * chore: remove useless code --------- Co-authored-by: JustSong --- controller/relay-image.go | 71 +++++++++++++++++-------------------- controller/relay-openai.go | 45 ++++++++++++------------ controller/relay-text.go | 72 ++++++++++++++++++-------------------- middleware/auth.go | 6 ---- 4 files changed, 88 insertions(+), 106 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index 1d1b71ba..0ff18309 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var imageRequest ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } // Size validation @@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var textResponse ImageResponse defer func(ctx context.Context) { - if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, quota) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return 
errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/controller/relay-openai.go b/controller/relay-openai.go index dcd20115..37867843 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O return nil, responseText } -func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { var textResponse TextResponse - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. 
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, resp.Body) if err != nil { return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } diff --git a/controller/relay-text.go b/controller/relay-text.go index 018c8d8a..dd9e7153 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" @@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { preConsumedQuota = 0 common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } - if consumeQuota && preConsumedQuota > 0 { + if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) @@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - if consumeQuota { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + if ratio != 0 && 
quota <= 0 { + quota = 1 } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + } + }() }(c.Request.Context()) switch apiType { @@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { - err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) + err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } diff --git a/middleware/auth.go b/middleware/auth.go index b0803612..ad7e64b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) From b4d67ca6144eff90a2f8bf8f7f1b262af67d2e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Fri, 24 Nov 2023 20:52:59 +0800 Subject: [PATCH 02/30] fix: add Message-ID header for email (#732) * feat: Add Message-ID to email headers to comply with RFC 5322 - Extract domain from SMTPFrom - Generate a unique Message-ID - Add Message-ID to email headers * chore: check slice length --------- Co-authored-by: JustSong --- common/email.go | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/common/email.go b/common/email.go index 74f4cccd..7d6963cc 100644 --- a/common/email.go +++ b/common/email.go @@ -1,6 +1,7 @@ package common import ( + "crypto/rand" "crypto/tls" "encoding/base64" "fmt" @@ -13,15 +14,32 @@ func SendEmail(subject string, receiver string, content string) error { SMTPFrom = SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + + // Extract domain from SMTPFrom + parts := strings.Split(SMTPFrom, "@") + var domain string + if len(parts) > 1 { + domain = parts[1] + } + // Generate a unique Message-ID + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + return err + } + messageId := fmt.Sprintf("<%x@%s>", buf, domain) + mail := []byte(fmt.Sprintf("To: %s\r\n"+ "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ + "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, content)) + receiver, SystemName, SMTPFrom, encodedSubject, messageId, content)) + auth := 
smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") - var err error + if SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, From 923e24534b4626f667479da4efc9a2867cdeff5a Mon Sep 17 00:00:00 2001 From: Tillman Bailee <51190972+YOMIkio@users.noreply.github.com> Date: Fri, 24 Nov 2023 20:56:53 +0800 Subject: [PATCH 03/30] fix: add Date header for email (#742) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复自建邮箱发送错误: INVALID HEADER Missing required header field: "Date" * chore: fix style --------- Co-authored-by: liyujie <29959257@qq.com> Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> Co-authored-by: JustSong --- common/email.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/email.go b/common/email.go index 7d6963cc..b915f0f9 100644 --- a/common/email.go +++ b/common/email.go @@ -7,6 +7,7 @@ import ( "fmt" "net/smtp" "strings" + "time" ) func SendEmail(subject string, receiver string, content string) error { @@ -33,9 +34,9 @@ func SendEmail(subject string, receiver string, content string) error { "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 + "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, content)) - + receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") From 3347a44023e8b1313ba640c222fad9d2a9ce755b Mon Sep 17 00:00:00 2001 From: Ian Li Date: Fri, 24 Nov 2023 21:10:18 +0800 Subject: [PATCH 04/30] feat: support Azure's Whisper model (#720) --- controller/relay-audio.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 01267fbf..89a311a0 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,11 +5,13 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + "strings" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -95,13 +97,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + baseURL = c.GetString("base_url") + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } + requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // 
https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + req.Header.Set("api-key", apiKey) + req.ContentLength = c.Request.ContentLength + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) From b4e43d97fd11d7f9cbe18e5c1239de8b6ab5b6a5 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:21:03 +0800 Subject: [PATCH 05/30] docs: add pr template --- pull_request_template.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 pull_request_template.md diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 00000000..bbcd969c --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1,3 @@ +close #issue_number + +我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file From b273464e777632bd45c4502df4fe12e6fdd264f2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:23:16 +0800 Subject: [PATCH 06/30] docs: update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 20c81361..7e6a7b38 100644 --- a/README.md +++ b/README.md @@ -51,15 +51,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 赞赏支持

-> **Note** +> [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> **Warning** +> [!WARNING] > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 -> **Warning** +> [!WARNING] > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! ## 功能 From 9889377f0e9260e852fb121d886ef3d9517ff8f9 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:39:44 +0800 Subject: [PATCH 07/30] feat: support claude-2.x (close #736) --- common/model-ratio.go | 2 ++ controller/model.go | 18 ++++++++++++++++++ controller/relay-claude.go | 4 +++- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 74c74a90..ccbc05dd 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -76,6 +76,8 @@ var ModelRatio = map[string]float64{ "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens + "claude-2.0": 5.51, // $11.02 / 1M tokens + "claude-2.1": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 59ea22e8..8f79524d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,24 @@ func init() { Root: "claude-2", Parent: nil, }, + { + Id: "claude-2.1", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.1", + Parent: nil, + }, + { + Id: "claude-2.0", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.0", + Parent: nil, + }, { Id: "ERNIE-Bot", Object: "model", diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1f4a3e7b..1b72b47d 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) } else if message.Role == "system" { - prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + if prompt == "" { + prompt = message.StringContent() + } } } prompt += "\n\nAssistant:" diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 654a5d51..bc3886a0 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -60,7 +60,7 @@ const EditChannel = () => { let localModels = []; switch (value) { case 14: - localModels = ['claude-instant-1', 'claude-2']; + localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; break; case 11: localModels = ['PaLM-2']; From 0e73418cdfef809fec7c8a2b6bb632a3b207eb88 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 26 Nov 2023 12:05:16 +0800 Subject: [PATCH 08/30] fix: fix log recording & error handling for relay-audio --- controller/relay-audio.go | 81 ++++++++++++++++++++++----------------- controller/relay-utils.go | 17 +++++--- 2 files changed, 57 insertions(+), 41 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 89a311a0..5b8898a7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } 
} - preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + var quota int + var preConsumedQuota int + switch relayMode { + case RelayModeAudioSpeech: + preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + quota = preConsumedQuota + default: + preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - quota := 0 // Check if user quota is enough - if relayMode == RelayModeAudioSpeech { - quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) - if quota > userQuota { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - } else { - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode == RelayModeAudioSpeech { - defer func(ctx context.Context) { - go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) - } else { + if relayMode != RelayModeAudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - defer func(ctx context.Context) { - quota := countTokenText(whisperResponse.Text, audioModel) - quotaDelta := quota - preConsumedQuota - go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) + quota = countTokenText(whisperResponse.Text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } + if resp.StatusCode != http.StatusOK { + if preConsumedQuota 
> 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + return relayErrorHandler(resp) + } + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index c7cd4766..391f28b4 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - err := model.PostConsumeTokenQuota(tokenId, quota) +func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c if err != nil { common.SysError("error update user quota cache: " + err.Error()) } - if quota != 0 { + // totalQuota is total quota consumed + if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } From b7570d5c772f5e4ba8aac4441f8ab9ec86b7aa55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Sun, 3 Dec 2023 17:34:59 +0800 Subject: [PATCH 09/30] feat: support dalle for Azure (#754) * feat: Add Message-ID to email headers to comply with RFC 5322 - Extract domain from SMTPFrom - Generate a unique Message-ID - Add Message-ID to email headers * chore: check slice length * feat: Add Azure compatibility for relayImageHelper - Handle Azure channel requestURL compatibility - Set api-key header for Azure channel authentication - Handle Azure channel request body fixes: https://github.com/songquanpeng/one-api/issues/751 * refactor: update implementation --------- Co-authored-by: JustSong --- controller/relay-audio.go | 7 +------ controller/relay-image.go | 18 ++++++++++++++++-- controller/relay-text.go | 6 +----- controller/relay-utils.go | 10 +++++++++- 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 5b8898a7..9e78dadc 100644 --- 
a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -98,12 +98,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - baseURL = c.GetString("base_url") + apiVersion := GetAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } diff --git a/controller/relay-image.go b/controller/relay-image.go index 0ff18309..b3248fcc 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" "github.com/gin-gonic/gin" ) @@ -101,8 +102,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + apiVersion := GetAPIVersion(c) + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) + } + var requestBody io.Reader - if isModelMapped { + if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) @@ -127,7 +135,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + token := c.Request.Header.Get("Authorization") + if channelType == common.ChannelTypeAzure { // Azure authentication + token = strings.TrimPrefix(token, "Bearer ") + req.Header.Set("api-key", token) + } else { + req.Header.Set("Authorization", token) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-text.go b/controller/relay-text.go index dd9e7153..a3e233d3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -129,11 +129,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } + apiVersion := GetAPIVersion(c) requestURL := strings.Split(requestURL, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) baseURL = c.GetString("base_url") diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 
391f28b4..839d6ae5 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -191,7 +191,6 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } - return fullRequestURL } @@ -216,3 +215,12 @@ func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} From 8f5b83562b1fc7831149a0cc33f49f044f011c42 Mon Sep 17 00:00:00 2001 From: Zhengyi Dong <95915069+simulacraliasing@users.noreply.github.com> Date: Sun, 3 Dec 2023 17:43:30 +0800 Subject: [PATCH 10/30] fix: fix "invalidPayload" error when request Azure dall-e-3 api without optional parameter (#764) * fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay * Revert "fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay" This reverts commit b526006ce07839f29b2d5b8c6a028147401ac295. * fix: add missing omitempty --------- Co-authored-by: JustSong --- controller/relay.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index f91ba6da..58ee8381 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -133,12 +133,12 @@ type TextRequest struct { type ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt" binding:"required"` - N int `json:"n"` - Size string `json:"size"` - Quality string `json:"quality"` - ResponseFormat string `json:"response_format"` - Style string `json:"style"` - User string `json:"user"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` } type WhisperResponse struct { From a3f80a3392a658ea09086f2c2f73f5f1701e9341 Mon Sep 17 00:00:00 2001 From: Tillman Bailee <51190972+YOMIkio@users.noreply.github.com> Date: Sun, 3 Dec 2023 20:10:57 +0800 Subject: [PATCH 11/30] feat: enable channel when test succeed (#771) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加功能: 渠道 - 测试所有通道; 设置 - 运营设置 - 监控设置 - 成功时自动启用通道 * refactor: update implementation --------- Co-authored-by: liyujie <29959257@qq.com> Co-authored-by: JustSong --- common/constants.go | 1 + controller/channel-test.go | 36 +++++++++++++++++++------- controller/relay-utils.go | 13 ++++++++++ i18n/en.json | 2 ++ model/option.go | 3 +++ web/src/components/ChannelsTable.js | 2 +- web/src/components/OperationSetting.js | 7 +++++ 7 files changed, 53 insertions(+), 11 deletions(-) diff --git a/common/constants.go b/common/constants.go index c7d3f222..f6860f67 100644 --- a/common/constants.go +++ b/common/constants.go @@ -78,6 +78,7 @@ var QuotaForInviter = 0 var QuotaForInvitee = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false diff --git a/controller/channel-test.go b/controller/channel-test.go index 1b0b745a..bba9a657 100644 --- 
a/controller/channel-test.go +++ b/controller/channel-test.go @@ -81,6 +81,9 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil } if response.Usage.CompletionTokens == 0 { + if response.Error.Message == "" { + response.Error.Message = "补全 tokens 非预期返回 0" + } return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error } return nil, nil @@ -142,20 +145,32 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { +func notifyRootUser(subject string, content string) { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) err := common.SendEmail(subject, common.RootUserEmail, content) if err != nil { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } +// disable & notify +func disableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +// enable & notify +func enableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} + func testAllChannels(notify bool) error { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() @@ -178,20 +193,21 @@ func testAllChannels(notify bool) error { } go func() { for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { - continue - } + isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() - if milliseconds > disableThreshold { + if isChannelEnabled && milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } + if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { + enableChannel(channel.Id, channel.Name) + } channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 839d6ae5..38408c7f 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -145,6 +145,19 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { return false } +func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + 
return true +} + func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") diff --git a/i18n/en.json b/i18n/en.json index 9b2ca4c8..b0deb83a 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -119,6 +119,7 @@ " 年 ": " y ", "未测试": "Not tested", "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", @@ -139,6 +140,7 @@ "启用": "Enable", "编辑": "Edit", "添加新的渠道": "Add a new channel", + "测试所有通道": "Test all channels", "测试所有已启用通道": "Test all enabled channels", "更新所有已启用通道余额": "Update the balance of all enabled channels", "刷新": "Refresh", diff --git a/model/option.go b/model/option.go index 4ef4d260..bb8b709c 100644 --- a/model/option.go +++ b/model/option.go @@ -34,6 +34,7 @@ func InitOptionMap() { common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) @@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) { common.EmailDomainRestrictionEnabled = boolValue case "AutomaticDisableChannelEnabled": common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue case "ApproximateTokenEnabled": common.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index d44ea2d7..5d68e2da 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -234,7 +234,7 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); + showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); } else { showError(message); } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index bf8b5ffd..3b52bb27 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -16,6 +16,7 @@ const OperationSetting = () => { ChatLink: '', QuotaPerUnit: 0, AutomaticDisableChannelEnabled: '', + AutomaticEnableChannelEnabled: '', ChannelDisableThreshold: 0, LogConsumeEnabled: '', DisplayInCurrencyEnabled: '', @@ -269,6 +270,12 @@ const OperationSetting = () => { name='AutomaticDisableChannelEnabled' onChange={handleInputChange} /> + { submitConfig('monitor').then(); From 01f7b0186fae589e0e5fb83ab0e6d033ba5339aa Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Dec 2023 20:45:11 +0800 Subject: [PATCH 
12/30] chore: add routes --- router/relay-router.go | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/router/relay-router.go b/router/relay-router.go index 912f4989..24edc9a9 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -35,12 +35,38 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.GET("/files/:id", controller.RelayNotImplemented) relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) relayV1Router.POST("/moderations", controller.Relay) + relayV1Router.POST("/assistants", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.GET("/assistants", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id", controller.RelayNotImplemented) + relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) } } From 
379074f7d011df8fcbfe3cd6bdfbca8967d3ab91 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 10 Dec 2023 17:22:52 +0800 Subject: [PATCH 13/30] feat: support plugin for ali channel (close #797) --- controller/relay-text.go | 3 +++ middleware/distributor.go | 2 ++ web/src/pages/Channel/EditChannel.js | 14 ++++++++++++++ 3 files changed, 19 insertions(+) diff --git a/controller/relay-text.go b/controller/relay-text.go index a3e233d3..a69c7f8b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -360,6 +360,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..8be986c9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -89,6 +89,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } c.Next() } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index bc3886a0..62e8a155 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -343,6 +343,20 @@ const EditChannel = () => { ) } + { + inputs.type === 17 && ( + + + + ) + } Date: Sun, 10 Dec 2023 18:39:14 +0800 Subject: [PATCH 14/30] feat: refactor response parsing logic to support multiple formats (#782) * feat: Refactor response parsing logic to support multiple formats The parsing logic for responses in relay.go and relay-audio.go was refactored to support multiple response formats - 'json', 'text', 'srt', 'verbose_json', and 'vtt'. The existing `WhisperResponse` struct was renamed to `WhisperJsonResponse` and a new struct `WhisperVerboseJsonResponse` was added to support the 'verbose_json' format. Additional parsing functions were added to extract text from these new response types. This change was necessary to make the parsing logic more flexible and extendable for different types of responses. 
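[To make the srt case concrete before the diff: an SRT cue is an index line, a "-->" timing line, and then the subtitle text, so the plain text can be recovered by keeping only the line that follows each timing marker, which is the approach the patch's getTextFromSRT below takes (one text line per cue). A standalone sketch under that assumption:

package main

import (
    "bufio"
    "fmt"
    "strings"
)

// textFromSRT keeps only the subtitle line that follows each "-->" timing
// line, mirroring the single-line-per-cue behavior of getTextFromSRT.
func textFromSRT(srt string) (string, error) {
    scanner := bufio.NewScanner(strings.NewReader(srt))
    var builder strings.Builder
    textLine := false
    for scanner.Scan() {
        line := scanner.Text()
        if strings.Contains(line, "-->") {
            textLine = true // the next line is the cue's text
            continue
        }
        if textLine {
            builder.WriteString(line)
            textLine = false
        }
    }
    if err := scanner.Err(); err != nil {
        return "", err
    }
    return builder.String(), nil
}

func main() {
    srt := "1\n00:00:00,000 --> 00:00:01,500\nHello\n\n2\n00:00:01,500 --> 00:00:03,000\n world\n"
    text, err := textFromSRT(srt)
    if err != nil {
        panic(err)
    }
    fmt.Println(text) // "Hello world"
}

getTextFromVTT can simply delegate to the SRT parser because WebVTT cues use the same "-->" timing marker.]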
* chore: update name --------- Co-authored-by: JustSong --- controller/relay-audio.go | 85 ++++++++++++++++++++++++++++++++++++--- controller/relay.go | 23 ++++++++++- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..2247f4c7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + default: + return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := 
scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse WhisperJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay.go b/controller/relay.go index 58ee8381..0e660a68 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type ImageRequest struct { User string `json:"user,omitempty"` } -type WhisperResponse struct { +type WhisperJSONResponse struct { Text string `json:"text,omitempty"` } +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + type TextToSpeechRequest struct { Model string `json:"model" binding:"required"` Input string `json:"input" binding:"required"` From 4c5feee0b66e0d42091ef9fa4bd1f5f3f78ec38f Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sun, 10 Dec 2023 19:39:46 +0800 Subject: [PATCH 15/30] feat: add image counter for gpt-4 vision (#795) --- common/image/image.go | 47 +++++++++++ common/image/image_test.go | 154 +++++++++++++++++++++++++++++++++++++ controller/relay-text.go | 4 + controller/relay-utils.go | 105 ++++++++++++++++++++++++- go.mod | 6 +- go.sum | 6 +- 6 files changed, 315 insertions(+), 7 deletions(-) create mode 100644 common/image/image.go create mode 100644 common/image/image_test.go diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 00000000..cbb656ad --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,47 @@ +package image + +import ( + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + + _ "golang.org/x/image/webp" +) + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") + base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) + img, _, err := image.DecodeConfig(base64) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 00000000..366eda6e --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,154 @@ +package image_test 
+ +import ( + "encoding/base64" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := 
base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a69c7f8b..c3d54059 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -410,6 +410,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { + if promptTokens != textResponse.PromptTokens { + common.SysError(fmt.Sprintf("prompt tokens not match, expected %d, actual %d", promptTokens, textResponse.PromptTokens)) + } + quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 38408c7f..9deca75a 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,10 +3,13 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" "io" + "math" "net/http" "one-api/common" + "one-api/common/image" "one-api/model" "strconv" "strings" @@ -87,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.StringContent()) + switch v := message.Content.(type) { + case string: + tokenNum += getTokenNum(tokenEncoder, v) + case []any: + for _, it := range v { + m := it.(map[string]any) + switch m["type"] { + case "text": + tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) + case "image_url": + imageUrl, ok := m["image_url"].(map[string]any) + if ok { + url := imageUrl["url"].(string) + detail := "" + if imageUrl["detail"] != nil { + detail = imageUrl["detail"].(string) + } + imageTokens, err := countImageTokens(url, detail) + if err != nil { + common.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + } + } + } + } tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -98,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int { return tokenNum } +const ( + lowDetailCost = 85 + highDetailCostPerTile = 170 + additionalCost = 85 +) + +// https://platform.openai.com/docs/guides/vision/calculating-costs +// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb +func countImageTokens(url string, detail string) (_ int, err error) { + var fetchSize = true + var width, height int + // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding + // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. 
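+	// For reference, the tile math implemented below works out as follows
+	// (a sketch assuming the documented rules, not an exact reproduction of
+	// OpenAI's accounting): a 1024px x 1024px image in high detail is first
+	// scaled so its shorter side is 768px, giving 768px x 768px; that spans
+	// ceil(768/512) * ceil(768/512) = 4 tiles of 512px, so the cost is
+	// 4*170 + 85 = 765 tokens.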
+ // According to the official guide, "low" disable the high-res model, + // and only receive low-res 512px x 512px version of the image, indicating + // that image is treated as low-res when size is smaller than 512px x 512px, + // then we can assume that image size larger than 512px x 512px is treated + // as high-res. Then we have the following logic: + // if detail == "" || detail == "auto" { + // width, height, err = image.GetImageSize(url) + // if err != nil { + // return 0, err + // } + // fetchSize = false + // // not sure if this is correct + // if width > 512 || height > 512 { + // detail = "high" + // } else { + // detail = "low" + // } + // } + + // However, in my test, it seems to be always the same as "high". + // The following image, which is 125x50, is still treated as high-res, taken + // 255 tokens in the response of non-stream chat completion api. + // https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg + if detail == "" || detail == "auto" { + // assume by test, not sure if this is correct + detail = "high" + } + switch detail { + case "low": + return lowDetailCost, nil + case "high": + if fetchSize { + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err + } + } + if width > 2048 || height > 2048 { // max(width, height) > 2048 + ratio := float64(2048) / math.Max(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + if width > 768 && height > 768 { // min(width, height) > 768 + ratio := float64(768) / math.Min(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) + result := numSquares*highDetailCostPerTile + additionalCost + return result, nil + default: + return 0, errors.New("invalid detail option") + } +} + func countTokenInput(input any, model string) int { - switch input.(type) { + switch v := input.(type) { case string: - return countTokenText(input.(string), model) + return countTokenText(v, model) case []string: text := "" - for _, s := range input.([]string) { + for _, s := range v { text += s } return countTokenText(text, model) diff --git a/go.mod b/go.mod index 10b78d68..1fe5eabc 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.14.0 + golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -26,6 +28,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -50,12 +53,13 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 
// indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4865bcaa..fb252aa7 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= @@ -168,8 +170,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 2a70744dbf266f25170ae04ca4ae7fe007ef6018 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 10 Dec 2023 19:53:33 +0800 Subject: [PATCH 16/30] feat: add panic recover middleware --- middleware/recover.go | 26 ++++++++++++++++++++++++++ router/relay-router.go | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 middleware/recover.go diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..c3a3d748 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + common.SysError(fmt.Sprintf("panic detected: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. 
Please submit an issue here: https://github.com/songquanpeng/one-api", err),
+					"type":    "one_api_panic",
+				},
+			})
+			c.Abort()
+		}
+	}()
+	c.Next()
+	}
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 24edc9a9..56ab9b28 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
 		modelsRouter.GET("/:model", controller.RetrieveModel)
 	}
 	relayV1Router := router.Group("/v1")
-	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+	relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
 	{
 		relayV1Router.POST("/completions", controller.Relay)
 		relayV1Router.POST("/chat/completions", controller.Relay)
From 366b82128f89a328f096da6951cbafebb6b0060f Mon Sep 17 00:00:00 2001
From: JustSong 
Date: Sun, 10 Dec 2023 20:44:37 +0800
Subject: [PATCH 17/30] fix: remove incorrect logging

---
 controller/relay-text.go | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/controller/relay-text.go b/controller/relay-text.go
index c3d54059..a69c7f8b 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -410,10 +410,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
-			if promptTokens != textResponse.PromptTokens {
-				common.SysError(fmt.Sprintf("prompt tokens not match, expected %d, actual %d", promptTokens, textResponse.PromptTokens))
-			}
-
 			quota := 0
 			completionRatio := common.GetCompletionRatio(textRequest.Model)
 			promptTokens = textResponse.Usage.PromptTokens
From 5cf23d8698cfebd720a6722e4b89a839c8a804c0 Mon Sep 17 00:00:00 2001
From: David Zhuang 
Date: Sat, 16 Dec 2023 23:48:32 -0500
Subject: [PATCH 18/30] feat: add Google Gemini Pro support (#826)

* feat: Add Google Gemini Pro, fix #810

* feat: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini

* refactor: removing unused if statement

* feat: Add dummy model message for system message in gemini model

* chore: update implementation

---------

Co-authored-by: JustSong 
---
 README.en.md                           |   2 +-
 README.ja.md                           |   2 +-
 README.md                              |   2 +-
 common/constants.go                    |   2 +
 common/model-ratio.go                  |   1 +
 controller/channel-test.go             |   2 +
 controller/model.go                    |   9 +
 controller/relay-gemini.go             | 281 +++++++++++++++++++++++++
 controller/relay-text.go               |  49 +++++
 middleware/distributor.go              |   2 +
 web/src/constants/channel.constants.js |   1 +
 web/src/pages/Channel/EditChannel.js   |   3 +
 12 files changed, 353 insertions(+), 3 deletions(-)
 create mode 100644 controller/relay-gemini.go

diff --git a/README.en.md b/README.en.md
index 9345a219..82dceb5b 100644
--- a/README.en.md
+++ b/README.en.md
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
 1. 
Support for multiple large models: + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude Series Models](https://anthropic.com) - + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) diff --git a/README.ja.md b/README.ja.md index 6faf9bee..089fc2b5 100644 --- a/README.ja.md +++ b/README.ja.md @@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に 1. 複数の大型モデルをサポート: + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) - + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) diff --git a/README.md b/README.md index 7e6a7b38..8a1d6caf 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) - + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) diff --git a/common/constants.go b/common/constants.go index f6860f67..60700ec8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -187,6 +187,7 @@ const ( ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 + ChannelTypeGemini = 24 ) var ChannelBaseURLs = []string{ @@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{ "https://api.aiproxy.io", // 21 "https://fastgpt.run/api/openapi", // 22 "https://hunyuan.cloud.tencent.com", //23 + "", //24 } diff --git a/common/model-ratio.go b/common/model-ratio.go index ccbc05dd..c054fa5f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -83,6 +83,7 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/controller/channel-test.go b/controller/channel-test.go index bba9a657..3aaa4897 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai 
switch channel.Type { case common.ChannelTypePaLM: fallthrough + case common.ChannelTypeGemini: + fallthrough case common.ChannelTypeAnthropic: fallthrough case common.ChannelTypeBaidu: diff --git a/controller/model.go b/controller/model.go index 8f79524d..5c8aebc0 100644 --- a/controller/model.go +++ b/controller/model.go @@ -423,6 +423,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "gemini-pro", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go new file mode 100644 index 00000000..455e30d8 --- /dev/null +++ b/controller/relay-gemini.go @@ -0,0 +1,281 @@ +package controller + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + + "github.com/gin-gonic/gin" +) + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + }, + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: textRequest.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range textRequest.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: message.StringContent(), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt 
from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "ok", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: candidate.Content.Parts[0].Text, + }, + FinishReason: stopFinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text + } + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + responseText += geminiResponse.Candidates[0].Content.Parts[0].Text + } + 
jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if len(geminiResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a69c7f8b..211a34b3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -27,6 +27,7 @@ const ( APITypeXunfei APITypeAIProxyLibrary APITypeTencent + APITypeGemini ) var httpClient *http.Client @@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAIProxyLibrary case common.ChannelTypeTencent: apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -177,6 +180,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey + case APITypeGemini: + requestBaseURL := "https://generativelanguage.googleapis.com" + if baseURL != "" { + requestBaseURL = baseURL + } + version := "v1" + if c.GetString("api_version") != "" { + version = c.GetString("api_version") + } + action := "generateContent" + // actually gemini does not support stream, it's fake + //if textRequest.Stream { + // action = "streamGenerateContent" + //} + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, 
version, textRequest.Model, action) + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -274,6 +295,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeGemini: + geminiChatRequest := requestOpenAI2Gemini(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: zhipuRequest := requestOpenAI2Zhipu(textRequest) jsonStr, err := json.Marshal(zhipuRequest) @@ -367,6 +395,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("Authorization", apiKey) case APITypePaLM: // do not set Authorization header + case APITypeGemini: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -527,6 +557,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeGemini: + if textRequest.Stream { + err, responseText := geminiChatStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } case APITypeZhipu: if isStream { err, usage := zhipuStreamHandler(c, resp) diff --git a/middleware/distributor.go b/middleware/distributor.go index 8be986c9..81338130 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) case common.ChannelTypeAli: diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 76407745..264dbefb 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 62e8a155..114e5933 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -83,6 +83,9 @@ const EditChannel = () => { case 23: localModels = ['hunyuan']; break; + case 24: + localModels = ['gemini-pro']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From 58dee76bf7828718021f924dd741c3754515cd7c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 
16:16:18 +0800 Subject: [PATCH 19/30] fix: fix Gemini stream problem --- controller/relay-gemini.go | 81 ++++++++++++++++++++++---------------- controller/relay-text.go | 7 ++-- controller/relay.go | 2 +- 3 files changed, 51 insertions(+), 39 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 455e30d8..4c2daba9 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -1,11 +1,13 @@ package controller import ( + "bufio" "encoding/json" "fmt" "io" "net/http" "one-api/common" + "strings" "github.com/gin-gonic/gin" ) @@ -180,50 +182,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - responseText += geminiResponse.Candidates[0].Content.Parts[0].Text - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) stopChan <- true }() setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return 
true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) diff --git a/controller/relay-text.go b/controller/relay-text.go index 211a34b3..b53b0caa 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -190,10 +190,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { version = c.GetString("api_version") } action := "generateContent" - // actually gemini does not support stream, it's fake - //if textRequest.Stream { - // action = "streamGenerateContent" - //} + if textRequest.Stream { + action = "streamGenerateContent" + } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") diff --git a/controller/relay.go b/controller/relay.go index 0e660a68..15021997 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` } `json:"delta"` - FinishReason *string `json:"finish_reason"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { From 7069c49bdf4e4ae39ab4944faaa5fd66b121d3fb Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:06:37 +0800 Subject: [PATCH 20/30] fix: fix xunfei panic error (close #820) --- controller/relay-xunfei.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 00ec8981..904e6d14 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin case stop = <-stopChan: } } - + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) From 6acb9537a921008813eb691e021aa87f316fb82e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:33:27 +0800 Subject: [PATCH 21/30] fix: try to return a more meaningful error message (close #817) --- controller/relay-utils.go | 57 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 9deca75a..a6a1f0f6 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -263,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +type GeneralErrorResponse struct { + Error OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return 
e.Response.Error.Message + } + return "" +} + func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), @@ -281,12 +322,20 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr if err != nil { return } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) if err != nil { return } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + openAIErrorWithStatusCode.OpenAIError = errResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() + } + if openAIErrorWithStatusCode.OpenAIError.Message == "" { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } return } From 66f06e5d6f14828c3522fb59bfca35c612685bef Mon Sep 17 00:00:00 2001 From: Ghostz <137054651+ye4293@users.noreply.github.com> Date: Sun, 17 Dec 2023 18:54:08 +0800 Subject: [PATCH 22/30] feat: reset image num to 1 when not given (#821) * Update relay-image.go * fix: reset image num to 1 when not given --------- Co-authored-by: JustSong --- controller/relay-image.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index b3248fcc..bf916474 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -19,7 +19,6 @@ func isWithinRange(element string, value int) bool { if _, ok := common.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] max := common.DalleGenerationImageAmounts[element][1] @@ -42,6 +41,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + // Size validation if imageRequest.Size != "" { imageSize = imageRequest.Size From 7d6a169669f1eb2666abb0747964ab1cd9f86f12 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Sun, 17 Dec 2023 19:17:00 +0800 Subject: [PATCH 23/30] feat: able to set sqlite busy_timeout (#818) * add sqlite busy_timeout=3000 * chore: update impl --------- Co-authored-by: JustSong --- README.md | 1 + common/database.go | 1 + model/main.go | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8a1d6caf..916e4331 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ graph LR + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 ### 命令行参数 1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/database.go b/common/database.go index c7e9fd52..76f2cd55 100644 --- a/common/database.go +++ b/common/database.go @@ -4,3 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/model/main.go b/model/main.go index 08182634..bfd6888b 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } From 0fe26cc4bd9d2e85b47ac6ae23fc974295c72f52 Mon Sep 17 00:00:00 2001 From: Oliver Lee Date: Sun, 17 Dec 2023 19:43:23 +0800 Subject: [PATCH 24/30] feat: update ali relay implementation (#830) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修改通译千问最新接口:1.删除history参数,改用官方推荐的messages参数 2.整理messages参数顺序,补充必要上下文信息 3.用autogen调试测试通过 * chore: update impl --------- Co-authored-by: JustSong --- common/model-ratio.go | 6 +++-- controller/model.go | 18 +++++++++++++++ controller/relay-ali.go | 33 ++++++++-------------------- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index c054fa5f..d1c96d96 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -88,8 +88,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": 1.4286, // ¥0.02 / 1k tokens + "qwen-max": 1.4286, // ¥0.02 / 1k tokens + "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 5c8aebc0..9ae40f5c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -486,6 +486,24 @@ func init() { Root: "qwen-plus", Parent: nil, }, + { + Id: "qwen-max", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max", + Parent: nil, + }, + { + Id: "qwen-max-longcontext", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max-longcontext", + Parent: nil, + }, { Id: "text-embedding-v1", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index b41ca327..65626f6a 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -13,13 +13,13 @@ import ( // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` + Content string `json:"content"` + Role string `json:"role"` } type AliInput struct { - Prompt string `json:"prompt"` - 
History []AliMessage `json:"history"` + //Prompt string `json:"prompt"` + Messages []AliMessage `json:"messages"` } type AliParameters struct { @@ -83,32 +83,17 @@ type AliChatResponse struct { func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } + messages = append(messages, AliMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) } return &AliChatRequest{ Model: request.Model, Input: AliInput{ - Prompt: prompt, - History: messages, + Messages: messages, }, //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's // TopP: request.TopP, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 114e5933..364da69d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -69,7 +69,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: - localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; + localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; From bc6769826bb827ca323beac6118bfd15e829fee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Sun, 17 Dec 2023 19:49:08 +0800 Subject: [PATCH 25/30] feat: add condition to validate n value for non-Azure channels (#775) - Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range. 
- Fix Azure compatibility --- controller/relay-image.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index bf916474..7e1fed39 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -82,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + // channel not azure + if channelType != common.ChannelTypeAzure { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } } // map model name @@ -105,7 +108,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview From af378c59afe3510ebc28b7f081a7a8ccfa8de891 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:19:16 +0800 Subject: [PATCH 26/30] docs: update readme --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 916e4331..1a967ace 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI-SB](https://openai-sb.com) + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) - + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) + [x] 自定义渠道:例如各种未收录的第三方代理服务 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 From 461f5dab561ff09d1bada624058c943daa7b9f99 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:25:03 +0800 Subject: [PATCH 27/30] docs: update readme --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 1a967ace..ff9e0bc0 100644 --- a/README.md +++ b/README.md @@ -73,11 +73,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) -2. 支持配置镜像以及众多第三方代理服务: - + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) - + [x] [API2D](https://api2d.com/r/197971) - + [x] 自定义渠道:例如各种未收录的第三方代理服务 +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 
支持**多机部署**,[详见此处](#多机部署)。 From 97030e27f88ac27a9ab4bcc2484f7d8e83d29d04 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 23:30:45 +0800 Subject: [PATCH 28/30] fix: fix gemini panic (close #833) --- controller/relay-gemini.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 4c2daba9..2458458e 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -114,7 +114,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { Role: "model", Parts: []GeminiPart{ { - Text: "ok", + Text: "Okay", }, }, }) @@ -130,6 +130,16 @@ type GeminiChatResponse struct { PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` } +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` FinishReason string `json:"finishReason"` @@ -158,10 +168,13 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse Index: i, Message: Message{ Role: "assistant", - Content: candidate.Content.Parts[0].Text, + Content: "", }, FinishReason: stopFinishReason, } + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text + } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse @@ -169,9 +182,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { var choice ChatCompletionsStreamResponseChoice - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text - } + choice.Delta.Content = geminiResponse.GetResponseText() choice.FinishReason = &stopFinishReason var response ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -276,7 +287,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + completionTokens := countTokenText(geminiResponse.GetResponseText(), model) usage := Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, From 67c64e71c800a76b2fa7a5b6eb51e89a14479b0c Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 20 Dec 2023 21:45:33 +0800 Subject: [PATCH 29/30] fix: fix max_tokens check --- controller/relay-text.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/controller/relay-text.go b/controller/relay-text.go index b53b0caa..c49a2abe 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -58,6 +58,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } + if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) + } if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } From b7fcb319da622a6de340b34e9255efd5fae6cf19 Mon 
Sep 17 00:00:00 2001 From: JustSong Date: Wed, 20 Dec 2023 22:50:50 +0800 Subject: [PATCH 30/30] chore: check if SESSION_SECRET equals to random_string --- common/init.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/common/init.go b/common/init.go index 1e9c85ce..12df5f51 100644 --- a/common/init.go +++ b/common/init.go @@ -36,7 +36,11 @@ func init() { } if os.Getenv("SESSION_SECRET") != "" { - SessionSecret = os.Getenv("SESSION_SECRET") + if os.Getenv("SESSION_SECRET") == "random_string" { + SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + } else { + SessionSecret = os.Getenv("SESSION_SECRET") + } } if os.Getenv("SQLITE_PATH") != "" { SQLitePath = os.Getenv("SQLITE_PATH")
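
For reference, the SESSION_SECRET guard in this last patch boils down to the
following standalone sketch (function and variable names here are
illustrative, not taken from the repository):

package main

import (
	"fmt"
	"os"
)

// sessionSecret returns the configured SESSION_SECRET unless it is unset or
// still equals the documented placeholder "random_string"; in both cases the
// random default generated at startup is kept instead.
func sessionSecret(randomDefault string) string {
	v := os.Getenv("SESSION_SECRET")
	if v == "" || v == "random_string" {
		return randomDefault
	}
	return v
}

func main() {
	// e.g. SESSION_SECRET=random_string go run main.go prints the default
	fmt.Println(sessionSecret("generated-at-startup"))
}

Running with SESSION_SECRET=random_string exercises the fallback path, which
is the misconfiguration the patch warns about.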