Merge branch 'main' into main

This commit is contained in:
JustSong 2023-11-24 20:55:14 +08:00 committed by GitHub
commit a983f0e31f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 108 deletions

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -14,16 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
SMTPFrom = SMTPAccount SMTPFrom = SMTPAccount
} }
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content)) receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 { if SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,

View File

@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var imageRequest ImageRequest var imageRequest ImageRequest
if consumeQuota { err := common.UnmarshalBodyReusable(c, &imageRequest)
err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} }
// Size validation // Size validation
@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*imageCostRatio*1000) * imageRequest.N quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse var textResponse ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota)
err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil {
if err != nil { common.SysError("error consuming token remain quota: " + err.Error())
common.SysError("error consuming token remain quota: " + err.Error()) }
} err = model.CacheUpdateUserQuota(userId)
err = model.CacheUpdateUserQuota(userId) if err != nil {
if err != nil { common.SysError("error update user quota cache: " + err.Error())
common.SysError("error update user quota cache: " + err.Error()) }
} if quota != 0 {
if quota != 0 { tokenName := c.GetString("token_name")
tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id")
channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
}(c.Request.Context()) }(c.Request.Context())
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])

View File

@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText return nil, responseText
} }
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse var textResponse TextResponse
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body) if err != nil {
if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
} }

View File

@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var textRequest GeneralOpenAIRequest var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { err := common.UnmarshalBodyReusable(c, &textRequest)
err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} }
if relayMode == RelayModeModerations && textRequest.Model == "" { if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest" textRequest.Model = "text-moderation-latest"
@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
} }
if consumeQuota && preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) { defer func(ctx context.Context) {
// c.Writer.Flush() // c.Writer.Flush()
go func() { go func() {
if consumeQuota { quota := 0
quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens
promptTokens = textResponse.Usage.PromptTokens completionTokens = textResponse.Usage.CompletionTokens
completionTokens = textResponse.Usage.CompletionTokens quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 {
if ratio != 0 && quota <= 0 { quota = 1
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}() }()
}(c.Request.Context()) }(c.Request.Context())
switch apiType { switch apiType {
@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }

View File

@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])