Merge remote-tracking branch 'upstream/main' into fix_retry

Author: AhhhLiu
Date: 2023-12-02 11:20:48 +08:00
Commit: f81f53bca4
25 changed files with 523 additions and 236 deletions

View File

@ -51,15 +51,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
<a href="https://iamazing.cn/page/reward">赞赏支持</a> <a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p> </p>
> **Note** > [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> >
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> **Warning** > [!WARNING]
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> **Warning** > [!WARNING]
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456` > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
## 功能 ## 功能
@ -92,14 +92,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
12. 支持**用户邀请奖励**。 12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。 13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。
15. 支持模型映射,重定向用户的请求模型。 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功
16. 支持失败自动重试。 16. 支持失败自动重试。
17. 支持绘图接口。 17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置, 19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。 1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 API。 20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
21. 支持 Cloudflare Turnstile 用户校验。 21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式** 22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
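Feature 20 above notes that the management API can now be called with a system access token (a bearer token) instead of a session cookie, and suggests capturing requests from the web UI to learn the exact routes. A minimal sketch of such a call follows; the route `/api/example` and the server address are placeholders for illustration, not endpoints confirmed by this commit.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Placeholder route and address: capture a request from the web UI to
	// find the real management endpoints and payloads.
	req, err := http.NewRequest("GET", "http://localhost:3000/api/example", nil)
	if err != nil {
		panic(err)
	}
	// The system access token replaces the cookie-based session.
	req.Header.Set("Authorization", "Bearer <your system access token>")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}
```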

View File

@ -1,11 +1,13 @@
package common package common
import ( import (
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/smtp" "net/smtp"
"strings" "strings"
"time"
) )
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
SMTPFrom = SMTPAccount SMTPFrom = SMTPAccount
} }
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, content)) receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 { if SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
func UnmarshalBodyReusable(c *gin.Context, v any) error { func UnmarshalBodyReusable(c *gin.Context, v any) error {
@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil { if err != nil {
return err return err
} }
err = json.Unmarshal(requestBody, &v) contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: if non-JSON requests ever carry model variants, we will need to implement this
}
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,6 +6,29 @@ import (
"time" "time"
) )
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
// ModelRatio // ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@ -36,7 +59,11 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10, "text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,
@ -45,9 +72,12 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e": 8, "dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens

View File

@ -5,14 +5,15 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
@ -43,16 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
requestURL := common.ChannelBaseURLs[channel.Type] requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure { if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else { } else {
if channel.GetBaseURL() != "" { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = channel.GetBaseURL() requestURL = baseURL
} }
requestURL += "/v1/chat/completions"
}
// for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639
requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1)
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request) jsonData, err := json.Marshal(request)
if err != nil { if err != nil {
return err, nil return err, nil
@ -73,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
defer resp.Body.Close() defer resp.Body.Close()
var response TextResponse var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err, nil return err, nil
} }
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
if response.Usage.CompletionTokens == 0 { if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
} }

View File

@ -55,12 +55,21 @@ func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{ openAIModels = []OpenAIModels{
{ {
Id: "dall-e", Id: "dall-e-2",
Object: "model", Object: "model",
Created: 1677649963, Created: 1677649963,
OwnedBy: "openai", OwnedBy: "openai",
Permission: permission, Permission: permission,
Root: "dall-e", Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil, Parent: nil,
}, },
{ {
@ -72,6 +81,42 @@ func init() {
Root: "whisper-1", Root: "whisper-1",
Parent: nil, Parent: nil,
}, },
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{ {
Id: "gpt-3.5-turbo", Id: "gpt-3.5-turbo",
Object: "model", Object: "model",
@ -315,6 +360,24 @@ func init() {
Root: "claude-2", Root: "claude-2",
Parent: nil, Parent: nil,
}, },
{
Id: "claude-2.1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.1",
Parent: nil,
},
{
Id: "claude-2.0",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.0",
Parent: nil,
},
{ {
Id: "ERNIE-Bot", Id: "ERNIE-Bot",
Object: "model", Object: "model",

View File

@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := "" query := ""
if len(request.Messages) != 0 { if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content query = request.Messages[len(request.Messages)-1].StringContent()
} }
return &AIProxyLibraryRequest{ return &AIProxyLibraryRequest{
Model: request.Model, Model: request.Model,

View File

@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i] message := request.Messages[i]
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.StringContent(),
Bot: "Okay", Bot: "Okay",
}) })
continue continue
} else { } else {
if i == len(request.Messages)-1 { if i == len(request.Messages)-1 {
prompt = message.Content prompt = message.StringContent()
break break
} }
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.StringContent(),
Bot: request.Messages[i+1].Content, Bot: request.Messages[i+1].StringContent(),
}) })
i++ i++
} }

View File

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strings"
) )
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -21,16 +22,41 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
tokenName := c.GetString("token_name")
var ttsRequest TextToSpeechRequest
if relayMode == RelayModeAudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
if err != nil {
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
}
audioModel = ttsRequest.Model
// Check that the input is not longer than 4096 characters
if len(ttsRequest.Input) > 4096 {
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel) modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) var quota int
var preConsumedQuota int
switch relayMode {
case RelayModeAudioSpeech:
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
// Check if user quota is enough
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
@ -70,13 +96,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
baseURL = c.GetString("base_url")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
}
requestBody := c.Request.Body requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@ -93,47 +139,44 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
var audioResponse AudioResponse
if relayMode != RelayModeAudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var whisperResponse WhisperResponse
err = json.Unmarshal(responseBody, &whisperResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
quota = countTokenText(whisperResponse.Text, audioModel)
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) { defer func(ctx context.Context) {
go func() { go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context()) }(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }

View File

@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "assistant", Role: "assistant",
@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else { } else {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }

View File

@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
} else if message.Role == "assistant" { } else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" { } else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) if prompt == "" {
prompt = message.StringContent()
}
} }
} }
prompt += "\n\nAssistant:" prompt += "\n\nAssistant:"

View File

@ -6,44 +6,79 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin"
) )
func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max
}
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e" imageModel := "dall-e-2"
imageSize := "1024x1024"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var imageRequest ImageRequest var imageRequest ImageRequest
if consumeQuota { err := common.UnmarshalBodyReusable(c, &imageRequest)
err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) }
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check that the requested model/size combination is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
} }
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
} }
// Prompt validation // Prompt validation
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
} }
// Not "256x256", "512x512", or "1024x1024" // Check prompt length
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
} }
// N should between 1 and 10 // Number of generated images validation
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
} }
// map model name // map model name
@ -82,18 +117,9 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0 quota := int(ratio*imageCostRatio*1000) * imageRequest.N
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
@ -122,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse var textResponse ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota)
err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil {
if err != nil { common.SysError("error consuming token remain quota: " + err.Error())
common.SysError("error consuming token remain quota: " + err.Error()) }
} err = model.CacheUpdateUserQuota(userId)
err = model.CacheUpdateUserQuota(userId) if err != nil {
if err != nil { common.SysError("error update user quota cache: " + err.Error())
common.SysError("error update user quota cache: " + err.Error()) }
} if quota != 0 {
if quota != 0 { tokenName := c.GetString("token_name")
tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id")
channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
}(c.Request.Context()) }(c.Request.Context())
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
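To make the new image pricing concrete, here is a worked example of the quota formula introduced in this file (`quota := int(ratio*imageCostRatio*1000) * imageRequest.N`), using the DalleSizeRatios and ModelRatio values added earlier in this commit; the group ratio is assumed to be 1.

```go
package main

import "fmt"

func main() {
	// Values taken from the tables added in this commit.
	modelRatio := 20.0    // ModelRatio["dall-e-3"]
	groupRatio := 1.0     // assumed group ratio for this example
	imageCostRatio := 1.0 // DalleSizeRatios["dall-e-3"]["1024x1024"]
	imageCostRatio *= 2   // quality "hd" doubles the ratio for 1024x1024
	n := 1                // dall-e-3 currently allows n=1 only

	ratio := modelRatio * groupRatio
	quota := int(ratio*imageCostRatio*1000) * n
	fmt.Println(quota) // 40000 quota units for one HD 1024x1024 dall-e-3 image
}
```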

View File

@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText return nil, responseText
} }
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse var textResponse TextResponse
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body) if err != nil {
if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
} }
@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 { if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model) completionTokens += countTokenText(choice.Message.StringContent(), model)
} }
textResponse.Usage = Usage{ textResponse.Usage = Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,

View File

@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
} }
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{ palmMessage := PaLMChatMessage{
Content: message.Content, Content: message.StringContent(),
} }
if message.Role == "user" { if message.Role == "user" {
palmMessage.Author = "0" palmMessage.Author = "0"

View File

@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "assistant", Role: "assistant",
@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue continue
} }
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Content: message.Content, Content: message.StringContent(),
Role: message.Role, Role: message.Role,
}) })
} }

View File

@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var textRequest GeneralOpenAIRequest var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { err := common.UnmarshalBodyReusable(c, &textRequest)
err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} }
if relayMode == RelayModeModerations && textRequest.Model == "" { if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest" textRequest.Model = "text-moderation-latest"
@ -147,7 +144,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613") model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
} }
case APITypeClaude: case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete" fullRequestURL = "https://api.anthropic.com/v1/complete"
@ -233,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
} }
if consumeQuota && preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@ -367,6 +366,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
case APITypeTencent: case APITypeTencent:
req.Header.Set("Authorization", apiKey) req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
default: default:
req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Authorization", "Bearer "+apiKey)
} }
@ -410,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) { defer func(ctx context.Context) {
// c.Writer.Flush() // c.Writer.Flush()
go func() { go func() {
if consumeQuota { quota := 0
quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens
promptTokens = textResponse.Usage.PromptTokens completionTokens = textResponse.Usage.CompletionTokens
completionTokens = textResponse.Usage.CompletionTokens quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 {
if ratio != 0 && quota <= 0 { quota = 1
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}() }()
}(c.Request.Context()) }(c.Request.Context())
switch apiType { switch apiType {
@ -454,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,15 +1,18 @@
package controller package controller
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
) )
var stopFinishReason = "stop" var stopFinishReason = "stop"
@ -84,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0 tokenNum := 0
for _, message := range messages { for _, message := range messages {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content) tokenNum += getTokenNum(tokenEncoder, message.StringContent())
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil { if message.Name != nil {
tokenNum += tokensPerName tokenNum += tokensPerName
@ -179,10 +182,37 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
func getFullRequestURL(baseURL string, requestURL string, channelType int) string { func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
} }
} }
return fullRequestURL return fullRequestURL
} }
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}

View File

@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "assistant", Role: "assistant",
@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else { } else {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }

View File

@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "system", Role: "system",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "user", Role: "user",
@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else { } else {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }

View File

@ -13,10 +13,49 @@ import (
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content any `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
} }
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
const ( const (
RelayModeUnknown = iota RelayModeUnknown = iota
RelayModeChatCompletions RelayModeChatCompletions
@ -25,24 +64,37 @@ const (
RelayModeModerations RelayModeModerations
RelayModeImagesGenerations RelayModeImagesGenerations
RelayModeEdits RelayModeEdits
RelayModeAudio RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Input any `json:"input,omitempty"` Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"` Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"` Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
} }
func (r GeneralOpenAIRequest) ParseInput() []string { func (r GeneralOpenAIRequest) ParseInput() []string {
@ -78,16 +130,30 @@ type TextRequest struct {
//Stream bool `json:"stream"` //Stream bool `json:"stream"`
} }
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct { type ImageRequest struct {
Prompt string `json:"prompt"` Model string `json:"model"`
N int `json:"n"` Prompt string `json:"prompt" binding:"required"`
Size string `json:"size"` N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
ResponseFormat string `json:"response_format"`
Style string `json:"style"`
User string `json:"user"`
} }
type AudioResponse struct { type WhisperResponse struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
} }
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`
@ -184,14 +250,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudio relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
} }
var err *OpenAIErrorWithStatusCode var err *OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case RelayModeImagesGenerations: case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode) err = relayImageHelper(c, relayMode)
case RelayModeAudio: case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode) err = relayAudioHelper(c, relayMode)
default: default:
err = relayTextHelper(c, relayMode) err = relayTextHelper(c, relayMode)
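One behavioural change worth highlighting from this file: `Message.Content` is now typed as `any` so that it can hold either a plain string or a list of typed parts, and `StringContent()` extracts only the text parts. A small sketch of how that behaves (the JSON shape mirrors OpenAI's vision-style messages; the snippet assumes it sits in the same `controller` package as `Message`):

```go
package controller

import (
	"encoding/json"
	"fmt"
)

// exampleStringContent shows that StringContent() returns string content
// as-is and, for a list of typed parts, concatenates only the "text" parts.
func exampleStringContent() {
	raw := `{"role":"user","content":[` +
		`{"type":"text","text":"Describe this image"},` +
		`{"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]}`
	var m Message
	if err := json.Unmarshal([]byte(raw), &m); err != nil {
		panic(err)
	}
	fmt.Println(m.StringContent()) // prints: Describe this image
}
```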

View File

@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])

View File

@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
} else { } else {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error err := common.UnmarshalBodyReusable(c, &modelRequest)
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return return
@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "dall-e" modelRequest.Model = "dall-e-2"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }

pull_request_template.md (new file, 3 lines)
View File

@ -0,0 +1,3 @@
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:

View File

@ -30,6 +30,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)

View File

@ -60,7 +60,7 @@ const EditChannel = () => {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 14: case 14:
localModels = ['claude-instant-1', 'claude-2']; localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
break; break;
case 11: case 11:
localModels = ['PaLM-2']; localModels = ['PaLM-2'];