From d85e356b6e6a7b22750f1aa93c3024862e82c58f Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Fri, 24 Nov 2023 20:42:29 +0800 Subject: [PATCH 1/8] refactor: remove consumeQuota related logic (#738) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: 删除relay-text中的consumeQuota变量 该变量始终为true,可以删除 * chore: remove useless code --------- Co-authored-by: JustSong --- controller/relay-image.go | 71 +++++++++++++++++-------------------- controller/relay-openai.go | 45 ++++++++++++------------ controller/relay-text.go | 72 ++++++++++++++++++-------------------- middleware/auth.go | 6 ---- 4 files changed, 88 insertions(+), 106 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index 1d1b71ba..0ff18309 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var imageRequest ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } // Size validation @@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var textResponse ImageResponse defer func(ctx context.Context) { - if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, quota) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return 
errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/controller/relay-openai.go b/controller/relay-openai.go index dcd20115..37867843 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O return nil, responseText } -func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { var textResponse TextResponse - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. 
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, resp.Body) if err != nil { return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } diff --git a/controller/relay-text.go b/controller/relay-text.go index 018c8d8a..dd9e7153 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" @@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { preConsumedQuota = 0 common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } - if consumeQuota && preConsumedQuota > 0 { + if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) @@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - if consumeQuota { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + if ratio != 0 && 
quota <= 0 { + quota = 1 } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + } + }() }(c.Request.Context()) switch apiType { @@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { - err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) + err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } diff --git a/middleware/auth.go b/middleware/auth.go index b0803612..ad7e64b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) From b4d67ca6144eff90a2f8bf8f7f1b262af67d2e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Fri, 24 Nov 2023 20:52:59 +0800 Subject: [PATCH 2/8] fix: add Message-ID header for email (#732) * feat: Add Message-ID to email headers to comply with RFC 5322 - Extract domain from SMTPFrom - Generate a unique Message-ID - Add Message-ID to email headers * chore: check slice length --------- Co-authored-by: JustSong --- common/email.go | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/common/email.go b/common/email.go index 74f4cccd..7d6963cc 100644 --- a/common/email.go +++ b/common/email.go @@ -1,6 +1,7 @@ package common import ( + "crypto/rand" "crypto/tls" "encoding/base64" "fmt" @@ -13,15 +14,32 @@ func SendEmail(subject string, receiver string, content string) error { SMTPFrom = SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + + // Extract domain from SMTPFrom + parts := strings.Split(SMTPFrom, "@") + var domain string + if len(parts) > 1 { + domain = parts[1] + } + // Generate a unique Message-ID + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + return err + } + messageId := fmt.Sprintf("<%x@%s>", buf, domain) + mail := []byte(fmt.Sprintf("To: %s\r\n"+ "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ + "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, content)) + receiver, SystemName, SMTPFrom, encodedSubject, messageId, content)) + auth := 
smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") - var err error + if SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, From 923e24534b4626f667479da4efc9a2867cdeff5a Mon Sep 17 00:00:00 2001 From: Tillman Bailee <51190972+YOMIkio@users.noreply.github.com> Date: Fri, 24 Nov 2023 20:56:53 +0800 Subject: [PATCH 3/8] fix: add Date header for email (#742) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复自建邮箱发送错误: INVALID HEADER Missing required header field: "Date" * chore: fix style --------- Co-authored-by: liyujie <29959257@qq.com> Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> Co-authored-by: JustSong --- common/email.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/email.go b/common/email.go index 7d6963cc..b915f0f9 100644 --- a/common/email.go +++ b/common/email.go @@ -7,6 +7,7 @@ import ( "fmt" "net/smtp" "strings" + "time" ) func SendEmail(subject string, receiver string, content string) error { @@ -33,9 +34,9 @@ func SendEmail(subject string, receiver string, content string) error { "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 + "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, content)) - + receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") From 3347a44023e8b1313ba640c222fad9d2a9ce755b Mon Sep 17 00:00:00 2001 From: Ian Li Date: Fri, 24 Nov 2023 21:10:18 +0800 Subject: [PATCH 4/8] feat: support Azure's Whisper model (#720) --- controller/relay-audio.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 01267fbf..89a311a0 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,11 +5,13 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + "strings" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -95,13 +97,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + baseURL = c.GetString("base_url") + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } + requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // 
https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + req.Header.Set("api-key", apiKey) + req.ContentLength = c.Request.ContentLength + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) From b4e43d97fd11d7f9cbe18e5c1239de8b6ab5b6a5 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:21:03 +0800 Subject: [PATCH 5/8] docs: add pr template --- pull_request_template.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 pull_request_template.md diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 00000000..bbcd969c --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1,3 @@ +close #issue_number + +我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file From b273464e777632bd45c4502df4fe12e6fdd264f2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:23:16 +0800 Subject: [PATCH 6/8] docs: update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 20c81361..7e6a7b38 100644 --- a/README.md +++ b/README.md @@ -51,15 +51,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 赞赏支持

-> **Note** +> [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> **Warning** +> [!WARNING] > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 -> **Warning** +> [!WARNING] > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! ## 功能 From 9889377f0e9260e852fb121d886ef3d9517ff8f9 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 21:39:44 +0800 Subject: [PATCH 7/8] feat: support claude-2.x (close #736) --- common/model-ratio.go | 2 ++ controller/model.go | 18 ++++++++++++++++++ controller/relay-claude.go | 4 +++- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 74c74a90..ccbc05dd 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -76,6 +76,8 @@ var ModelRatio = map[string]float64{ "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens + "claude-2.0": 5.51, // $11.02 / 1M tokens + "claude-2.1": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 59ea22e8..8f79524d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,24 @@ func init() { Root: "claude-2", Parent: nil, }, + { + Id: "claude-2.1", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.1", + Parent: nil, + }, + { + Id: "claude-2.0", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.0", + Parent: nil, + }, { Id: "ERNIE-Bot", Object: "model", diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1f4a3e7b..1b72b47d 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) } else if message.Role == "system" { - prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + if prompt == "" { + prompt = message.StringContent() + } } } prompt += "\n\nAssistant:" diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 654a5d51..bc3886a0 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -60,7 +60,7 @@ const EditChannel = () => { let localModels = []; switch (value) { case 14: - localModels = ['claude-instant-1', 'claude-2']; + localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; break; case 11: localModels = ['PaLM-2']; From 0e73418cdfef809fec7c8a2b6bb632a3b207eb88 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 26 Nov 2023 12:05:16 +0800 Subject: [PATCH 8/8] fix: fix log recording & error handling for relay-audio --- controller/relay-audio.go | 81 ++++++++++++++++++++++----------------- controller/relay-utils.go | 17 +++++--- 2 files changed, 57 insertions(+), 41 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 89a311a0..5b8898a7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } - 
preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + var quota int + var preConsumedQuota int + switch relayMode { + case RelayModeAudioSpeech: + preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + quota = preConsumedQuota + default: + preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - quota := 0 // Check if user quota is enough - if relayMode == RelayModeAudioSpeech { - quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) - if quota > userQuota { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - } else { - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode == RelayModeAudioSpeech { - defer func(ctx context.Context) { - go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) - } else { + if relayMode != RelayModeAudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - defer func(ctx context.Context) { - quota := countTokenText(whisperResponse.Text, audioModel) - quotaDelta := quota - preConsumedQuota - go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) + quota = countTokenText(whisperResponse.Text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } + if resp.StatusCode != http.StatusOK { + if preConsumedQuota > 0 
{ + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + return relayErrorHandler(resp) + } + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index c7cd4766..391f28b4 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - err := model.PostConsumeTokenQuota(tokenId, quota) +func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c if err != nil { common.SysError("error update user quota cache: " + err.Error()) } - if quota != 0 { + // totalQuota is total quota consumed + if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } }
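The settlement logic kept in relay-text.go by [PATCH 1/8] charges quota = ceil((promptTokens + completionTokens*completionRatio) * ratio), with a floor of 1 when the ratio is non-zero and 0 when no tokens were reported. A self-contained sketch of that arithmetic, using illustrative numbers rather than values taken from the repository:

package main

import (
	"fmt"
	"math"
)

// computeQuota mirrors the settlement formula used in relayTextHelper:
// ceil((promptTokens + completionTokens*completionRatio) * ratio),
// floored to 1 when ratio is non-zero, and zeroed when no tokens were used.
func computeQuota(promptTokens, completionTokens int, completionRatio, ratio float64) int {
	quota := int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
	if ratio != 0 && quota <= 0 {
		quota = 1
	}
	if promptTokens+completionTokens == 0 {
		// no usage reported usually means an upstream error; charge nothing
		quota = 0
	}
	return quota
}

func main() {
	// Illustrative numbers only: 100 prompt tokens, 200 completion tokens,
	// completion ratio 1.33, combined model*group ratio 1.5.
	fmt.Println(computeQuota(100, 200, 1.33, 1.5)) // 549
}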
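[PATCH 2/8] and [PATCH 3/8] make the outgoing mail acceptable to stricter SMTP servers and spam filters by adding Message-ID and Date headers per RFC 5322. A standalone sketch of building just those two headers; the helper name and the example address are illustrative, not part of common/email.go:

package main

import (
	"crypto/rand"
	"fmt"
	"strings"
	"time"
)

// buildHeaders assembles the extra headers added by the patches: a unique
// Message-ID derived from the sender's domain, and a Date header in
// RFC 1123 format with a numeric time zone.
func buildHeaders(from string) (string, error) {
	domain := ""
	if parts := strings.Split(from, "@"); len(parts) > 1 {
		domain = parts[1]
	}
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	messageId := fmt.Sprintf("<%x@%s>", buf, domain)
	return fmt.Sprintf("Message-ID: %s\r\nDate: %s\r\n", messageId, time.Now().Format(time.RFC1123Z)), nil
}

func main() {
	headers, err := buildHeaders("noreply@example.com") // illustrative address
	if err != nil {
		panic(err)
	}
	fmt.Print(headers)
}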
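[PATCH 4/8] routes Whisper transcription requests on Azure channels to the deployment-scoped endpoint and replaces the Bearer token with an api-key header. A rough sketch of the request shape, with placeholder resource, deployment, and api-version values; the URL layout follows the Azure quickstart linked in the patch:

package main

import (
	"fmt"
	"net/http"
	"os"
)

func main() {
	// Placeholder values for illustration only.
	baseURL := "https://my-resource.openai.azure.com"
	deployment := "whisper"
	apiVersion := "2023-09-01-preview"
	apiKey := os.Getenv("AZURE_OPENAI_KEY")

	// Azure expects the deployment name in the path and the version as a
	// query parameter, matching the URL built in relayAudioHelper for Azure.
	url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s",
		baseURL, deployment, apiVersion)

	req, err := http.NewRequest(http.MethodPost, url, nil) // real body: multipart form with the audio file
	if err != nil {
		panic(err)
	}
	// Azure authenticates with an api-key header rather than "Authorization: Bearer ...".
	req.Header.Set("api-key", apiKey)
	fmt.Println(req.Method, req.URL)
}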
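[PATCH 8/8] splits audio billing into a pre-consumed estimate and a later correction: on upstream failure the estimate is rolled back, otherwise only the difference between actual and pre-consumed quota is posted. A condensed sketch of that flow; postConsume below stands in for model.PostConsumeTokenQuota and is not the project's real function:

package main

import "fmt"

// postConsume stands in for model.PostConsumeTokenQuota: a positive delta
// charges the token, a negative delta returns quota to it.
func postConsume(tokenId int, delta int) {
	fmt.Printf("token %d: delta %d\n", tokenId, delta)
}

// settle condenses the two-step accounting in relayAudioHelper: an estimate
// is pre-consumed before the upstream call; afterwards the code either rolls
// the estimate back (upstream error) or charges only the difference between
// actual and pre-consumed quota.
func settle(tokenId, preConsumed, actual int, upstreamOK bool) {
	if !upstreamOK {
		if preConsumed > 0 {
			postConsume(tokenId, -preConsumed) // refund the whole estimate
		}
		return
	}
	postConsume(tokenId, actual-preConsumed) // may be negative, i.e. a partial refund
}

func main() {
	settle(1, 500, 320, true) // illustrative: posts -180, refunding the over-estimate
	settle(1, 500, 0, false)  // illustrative: rolls back the 500 pre-consumed
}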