diff --git a/README.md b/README.md index c6aecbb0..38ded40e 100644 --- a/README.md +++ b/README.md @@ -63,15 +63,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 赞赏支持

-> **Note** +> [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> **Warning** +> [!WARNING] > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 -> **Warning** +> [!WARNING] > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! ## 功能 @@ -104,14 +104,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 12. 支持**用户邀请奖励**。 13. 支持以美元为单位显示额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。 -15. 支持模型映射,重定向用户的请求模型。 +15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 16. 支持失败自动重试。 17. 支持绘图接口。 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -20. 支持通过系统访问Key访问管理 API。 +20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 diff --git a/common/email.go b/common/email.go index 74f4cccd..b915f0f9 100644 --- a/common/email.go +++ b/common/email.go @@ -1,11 +1,13 @@ package common import ( + "crypto/rand" "crypto/tls" "encoding/base64" "fmt" "net/smtp" "strings" + "time" ) func SendEmail(subject string, receiver string, content string) error { @@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error { SMTPFrom = SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + + // Extract domain from SMTPFrom + parts := strings.Split(SMTPFrom, "@") + var domain string + if len(parts) > 1 { + domain = parts[1] + } + // Generate a unique Message-ID + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + return err + } + messageId := fmt.Sprintf("<%x@%s>", buf, domain) + mail := []byte(fmt.Sprintf("To: %s\r\n"+ "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ + "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 + "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, content)) + receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") - var err error + if SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, diff --git a/common/gin.go b/common/gin.go index ffa1e218..f5012688 100644 --- a/common/gin.go +++ b/common/gin.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "io" + "strings" ) func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } - err = json.Unmarshal(requestBody, &v) + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + err = json.Unmarshal(requestBody, &v) + } else { + // skip for now + // TODO: someday non json request have variant model, we will need to implementation this + } if err != nil { return err } diff --git a/common/model-ratio.go b/common/model-ratio.go index 199fb5aa..2d441f3a 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -76,6 +76,8 @@ var ModelRatio = map[string]float64{ "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, "claude-2": 5.51, + "claude-2.0": 5.51, + "claude-2.1": 5.51, "ERNIE-Bot": 0.8572, "ERNIE-Bot-turbo": 0.5715, "ERNIE-Bot-4": 8.572, diff --git a/controller/channel-test.go b/controller/channel-test.go index b9a6c980..15924c66 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -43,14 +43,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) + requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL } + requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) } - jsonData, err := json.Marshal(request) if err != nil { return err, nil diff --git a/controller/model.go b/controller/model.go index 59ea22e8..8f79524d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,24 @@ func init() { Root: "claude-2", Parent: nil, }, + { + Id: "claude-2.1", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.1", + Parent: nil, + }, + { + Id: "claude-2.0", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.0", + Parent: nil, + }, { Id: "ERNIE-Bot", Object: "model", diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index d0159ce8..543954f7 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].Content + query = request.Messages[len(request.Messages)-1].StringContent() } return &AIProxyLibraryRequest{ Model: request.Model, diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 50dc743c..b41ca327 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { message := request.Messages[i] if message.Role == "system" { messages = append(messages, AliMessage{ - User: message.Content, + User: message.StringContent(), Bot: "Okay", }) continue } else { if i == len(request.Messages)-1 { - prompt = message.Content + prompt = message.StringContent() break } messages = append(messages, AliMessage{ - User: message.Content, - Bot: request.Messages[i+1].Content, + User: message.StringContent(), + Bot: request.Messages[i+1].StringContent(), }) i++ } diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 01267fbf..89a311a0 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,11 +5,13 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + "strings" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -95,13 +97,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + baseURL = c.GetString("base_url") + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + } + requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + req.Header.Set("api-key", apiKey) + req.ContentLength = c.Request.ContentLength + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ed08ac04..c75ec09a 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { if message.Role == "system" { messages = append(messages, BaiduMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, BaiduMessage{ Role: "assistant", @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } else { messages = append(messages, BaiduMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1f4a3e7b..1b72b47d 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) } else if message.Role == "system" { - prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + if prompt == "" { + prompt = message.StringContent() + } } } prompt += "\n\nAssistant:" diff --git a/controller/relay-image.go b/controller/relay-image.go index 16dbeb87..7ce860eb 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var imageRequest ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } // Size validation @@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -162,7 +159,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f", modelRatio, groupRatio) + logContent := fmt.Sprintf("费用:0.002* %.2f", modelRatio) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") @@ -171,23 +168,21 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 6bdfbc08..37867843 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O return nil, responseText } -func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { var textResponse TextResponse - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. @@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, resp.Body) if err != nil { return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } @@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.Content, model) + completionTokens += countTokenText(choice.Message.StringContent(), model) } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index a705b318..2bd0bcd8 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { } for _, message := range textRequest.Messages { palmMessage := PaLMChatMessage{ - Content: message.Content, + Content: message.StringContent(), } if message.Role == "user" { palmMessage.Author = "0" diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index 024468bc..f66bf38f 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { if message.Role == "system" { messages = append(messages, TencentMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, TencentMessage{ Role: "assistant", @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { continue } messages = append(messages, TencentMessage{ - Content: message.Content, + Content: message.StringContent(), Role: message.Role, }) } diff --git a/controller/relay-text.go b/controller/relay-text.go index 9e4739d5..063cb5be 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" @@ -147,7 +144,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0613") - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) } case APITypeClaude: fullRequestURL = "https://api.anthropic.com/v1/complete" @@ -233,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { preConsumedQuota = 0 common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } - if consumeQuota && preConsumedQuota > 0 { + if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) @@ -412,32 +411,31 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - if consumeQuota { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - inputPrice := ratio * 0.002 + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } + if quota != 0 { + inputPrice := ratio * 0.002 var logContent string if completionRatio == 1 { @@ -446,11 +444,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { outputPrice := inputPrice * completionRatio logContent = fmt.Sprintf("输入:$%.6g/1k tokens, 输出:$%.6g/1k tokens", inputPrice, outputPrice) } - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) } + }() }(c.Request.Context()) switch apiType { @@ -464,7 +462,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { - err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) + err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index dded706d..abb79896 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -4,14 +4,15 @@ import ( "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" "io" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" ) var stopFinishReason = "stop" @@ -86,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Content) + tokenNum += getTokenNum(tokenEncoder, message.StringContent()) tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -181,11 +182,16 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr func getFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } + return fullRequestURL } diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 91fb6042..00ec8981 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma if message.Role == "system" { messages = append(messages, XunfeiMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, XunfeiMessage{ Role: "assistant", @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma } else { messages = append(messages, XunfeiMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 7a4a582d..2e345ab5 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { if message.Role == "system" { messages = append(messages, ZhipuMessage{ Role: "system", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, ZhipuMessage{ Role: "user", @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } else { messages = append(messages, ZhipuMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } diff --git a/controller/relay.go b/controller/relay.go index cc126abd..74fd2262 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -12,10 +12,49 @@ import ( type Message struct { Role string `json:"role"` - Content string `json:"content"` + Content any `json:"content"` Name *string `json:"name,omitempty"` } +type ImageURL struct { + Url string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type TextContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` +} + +type ImageContent struct { + Type string `json:"type,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) StringContent() string { + content, ok := m.Content.(string) + if ok { + return content + } + contentList, ok := m.Content.([]any) + if ok { + var contentStr string + for _, contentItem := range contentList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + if contentMap["type"] == "text" { + if subStr, ok := contentMap["text"].(string); ok { + contentStr += subStr + } + } + } + return contentStr + } + return "" +} + const ( RelayModeUnknown = iota RelayModeChatCompletions @@ -31,19 +70,30 @@ const ( // https://platform.openai.com/docs/api-reference/chat +type ResponseFormat struct { + Type string `json:"type,omitempty"` +} + type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { @@ -201,9 +251,9 @@ func Relay(c *gin.Context) { relayMode = RelayModeEdits } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") { + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") { + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { relayMode = RelayModeAudioTranslation } var err *OpenAIErrorWithStatusCode diff --git a/middleware/auth.go b/middleware/auth.go index b0803612..ad7e64b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 00000000..bbcd969c --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1,3 @@ +close #issue_number + +我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index b2dd2bba..0f4d99de 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -62,7 +62,7 @@ const EditChannel = () => { let localModels = []; switch (value) { case 14: - localModels = ['claude-instant-1', 'claude-2']; + localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; break; case 11: localModels = ['PaLM-2'];