diff --git a/.gitignore b/.gitignore index 1b2cf071..60abb13e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ upload *.db build *.db-journal -logs \ No newline at end of file +logs +data \ No newline at end of file diff --git a/README.en.md b/README.en.md index 783c140c..9345a219 100644 --- a/README.en.md +++ b/README.en.md @@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. First, fork the code. 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). diff --git a/README.ja.md b/README.ja.md index fa3339c2..6faf9bee 100644 --- a/README.ja.md +++ b/README.ja.md @@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. まず、コードをフォークする。 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 diff --git a/README.md b/README.md index 79229a94..4ef6505c 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://console.closeai-asia.com/r/2412) + + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) @@ -160,6 +160,19 @@ sudo service nginx restart 初始账号用户名为 `root`,密码为 `123456`。 + +### 基于 Docker Compose 进行部署 + +> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 + +```shell +# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 +docker-compose up -d + +# 查看部署状态 +docker-compose ps +``` + ### 手动部署 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell @@ -249,6 +262,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. 首先 fork 一份代码。 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 diff --git a/common/model-ratio.go b/common/model-ratio.go index 8f4be8c3..74c74a90 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -3,8 +3,32 @@ package common import ( "encoding/json" "strings" + "time" ) +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} + // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf @@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{ "gpt-4-32k": 30, "gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "text-ada-001": 0.2, "text-babbage-001": 0.25, "text-curie-001": 1, @@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, "davinci": 10, "curie": 10, "babbage": 10, @@ -41,7 +72,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e": 8, + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens @@ -88,9 +120,24 @@ func GetModelRatio(name string) float64 { func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gpt-3.5") { + if strings.HasSuffix(name, "1106") { + return 2 + } + if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { + // TODO: clear this after 2023-12-11 + now := time.Now() + // https://platform.openai.com/docs/models/continuous-model-upgrades + // if after 2023-12-11, use 2 + if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { + return 2 + } + } return 1.333333 } if strings.HasPrefix(name, "gpt-4") { + if strings.HasSuffix(name, "preview") { + return 3 + } return 2 } if strings.HasPrefix(name, "claude-instant-1") { diff --git a/controller/channel-test.go b/controller/channel-test.go index c5d4d26c..1b0b745a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "one-api/common" "one-api/model" @@ -45,8 +46,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai if channel.Type == common.ChannelTypeAzure { requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { - if channel.GetBaseURL() != "" { - requestURL = channel.GetBaseURL() + if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { + requestURL = baseURL } requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) @@ -71,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } defer resp.Body.Close() var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) + body, err := io.ReadAll(resp.Body) if err != nil { return err, nil } + err = json.Unmarshal(body, &response) + if err != nil { + return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil + } if response.Usage.CompletionTokens == 0 { return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error } diff --git a/controller/model.go b/controller/model.go index 2a7dc538..59ea22e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -55,12 +55,21 @@ func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility openAIModels = []OpenAIModels{ { - Id: "dall-e", + Id: "dall-e-2", Object: "model", Created: 1677649963, OwnedBy: "openai", Permission: permission, - Root: "dall-e", + Root: "dall-e-2", + Parent: nil, + }, + { + Id: "dall-e-3", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e-3", Parent: nil, }, { @@ -72,6 +81,42 @@ func init() { Root: "whisper-1", Parent: nil, }, + { + Id: "tts-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1", + Parent: nil, + }, + { + Id: "tts-1-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-1106", + Parent: nil, + }, + { + Id: "tts-1-hd", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd", + Parent: nil, + }, + { + Id: "tts-1-hd-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", @@ -117,6 +162,15 @@ func init() { Root: "gpt-3.5-turbo-16k-0613", Parent: nil, }, + { + Id: "gpt-3.5-turbo-1106", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo-instruct", Object: "model", @@ -180,6 +234,24 @@ func init() { Root: "gpt-4-32k-0613", Parent: nil, }, + { + Id: "gpt-4-1106-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-1106-preview", + Parent: nil, + }, + { + Id: "gpt-4-vision-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-vision-preview", + Parent: nil, + }, { Id: "text-embedding-ada-002", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 53833108..01267fbf 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -21,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelId := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") + tokenName := c.GetString("token_name") + + var ttsRequest TextToSpeechRequest + if relayMode == RelayModeAudioSpeech { + // Read JSON + err := common.UnmarshalBodyReusable(c, &ttsRequest) + // Check if JSON is valid + if err != nil { + return errorWrapper(err, "invalid_json", http.StatusBadRequest) + } + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + } preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) @@ -31,22 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + + quota := 0 + // Check if user quota is enough + if relayMode == RelayModeAudioSpeech { + quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) + if quota > userQuota { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + } else { + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } } } @@ -93,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var audioResponse AudioResponse - defer func(ctx context.Context) { - go func() { - quota := countTokenText(audioResponse.Text, audioModel) + if relayMode == RelayModeAudioSpeech { + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + } else { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var whisperResponse WhisperResponse + err = json.Unmarshal(responseBody, &whisperResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + defer func(ctx context.Context) { + quota := countTokenText(whisperResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-image.go b/controller/relay-image.go index ccd52dce..1d1b71ba 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -6,15 +6,28 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) +func isWithinRange(element string, value int) bool { + if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + return false + } + + min := common.DalleGenerationImageAmounts[element][0] + max := common.DalleGenerationImageAmounts[element][1] + + return value >= min && value <= max +} + func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e" + imageModel := "dall-e-2" + imageSize := "1024x1024" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -31,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + // Size validation + if imageRequest.Size != "" { + imageSize = imageRequest.Size + } + + // Model validation + if imageRequest.Model != "" { + imageModel = imageRequest.Model + } + + imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] + + // Check if model is supported + if hasValidSize { + if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { + if imageSize == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + // Check prompt length + if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { + return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + // Number of generated images validation + if isWithinRange(imageModel, imageRequest.N) == false { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } // map model name @@ -82,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - sizeRatio := 1.0 - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 1 - } else if imageRequest.Size == "512x512" { - sizeRatio = 1.125 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1.25 - } - quota := int(ratio*sizeRatio*1000) * imageRequest.N + quota := int(ratio*imageCostRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) diff --git a/controller/relay-text.go b/controller/relay-text.go index 4806229e..018c8d8a 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -369,6 +369,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeTencent: req.Header.Set("Authorization", apiKey) + case APITypePaLM: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index e888efa7..e2b77a97 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,11 +1,13 @@ package controller import ( + "context" "encoding/json" "fmt" "io" "net/http" "one-api/common" + "one-api/model" "strconv" "strings" @@ -192,3 +194,20 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } + +func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + } +} diff --git a/controller/relay.go b/controller/relay.go index 1926110e..863267b4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,7 +24,9 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits - RelayModeAudio + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation ) // https://platform.openai.com/docs/api-reference/chat @@ -77,16 +79,30 @@ type TextRequest struct { //Stream bool `json:"stream"` } +// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create type ImageRequest struct { - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + ResponseFormat string `json:"response_format"` + Style string `json:"style"` + User string `json:"user"` } -type AudioResponse struct { +type WhisperResponse struct { Text string `json:"text,omitempty"` } +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -183,14 +199,22 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - relayMode = RelayModeAudio + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") { + relayMode = RelayModeAudioTranslation } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) - case RelayModeAudio: + case RelayModeAudioSpeech: + fallthrough + case RelayModeAudioTranslation: + fallthrough + case RelayModeAudioTranscription: err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode) diff --git a/docker-compose.yml b/docker-compose.yml index 9b814a03..30edb281 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,19 +9,19 @@ services: ports: - "3000:3000" volumes: - - ./data:/data + - ./data/oneapi:/data - ./logs:/app/logs environment: - - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 + - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - REDIS_CONN_STRING=redis://redis - SESSION_SECRET=random_string # 修改为随机字符串 - TZ=Asia/Shanghai # - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 - depends_on: - redis + - db healthcheck: test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] interval: 30s @@ -32,3 +32,18 @@ services: image: redis:latest container_name: redis restart: always + + db: + image: mysql:8.2.0 + restart: always + container_name: mysql + volumes: + - ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储 + ports: + - '3306:3306' + environment: + TZ: Asia/Shanghai # 设置时区 + MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 + MYSQL_USER: oneapi # 创建专用用户 + MYSQL_PASSWORD: '123456' # 设置专用用户密码 + MYSQL_DATABASE: one-api # 自动创建数据库 \ No newline at end of file diff --git a/middleware/distributor.go b/middleware/distributor.go index d80945fc..c4ddc3a0 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var modelRequest ModelRequest - var err error - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } + err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return @@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if modelRequest.Model == "" { - modelRequest.Model = "dall-e" + modelRequest.Model = "dall-e-2" } } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } diff --git a/model/log.go b/model/log.go index d26da9a2..3d3ffae3 100644 --- a/model/log.go +++ b/model/log.go @@ -94,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = tx.Where("created_at <= ?", endTimestamp) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err @@ -151,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa tx = tx.Where("model_name = ?", modelName) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota diff --git a/router/relay-router.go b/router/relay-router.go index e84f02db..912f4989 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 732189cb..d44ea2d7 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -286,17 +286,15 @@ const ChannelsTable = () => { if (channels.length === 0) return; setLoading(true); let sortedChannels = [...channels]; - if (typeof sortedChannels[0][key] === 'string') { - sortedChannels.sort((a, b) => { + sortedChannels.sort((a, b) => { + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings return ('' + a[key]).localeCompare(b[key]); - }); - } else { - sortedChannels.sort((a, b) => { - if (a[key] === b[key]) return 0; - if (a[key] > b[key]) return -1; - if (a[key] < b[key]) return 1; - }); - } + } + }); if (sortedChannels[0].id === channels[0].id) { sortedChannels.reverse(); } @@ -304,6 +302,7 @@ const ChannelsTable = () => { setLoading(false); }; + return ( <>
diff --git a/web/src/components/RedemptionsTable.js b/web/src/components/RedemptionsTable.js index ae8b5b03..dfd59685 100644 --- a/web/src/components/RedemptionsTable.js +++ b/web/src/components/RedemptionsTable.js @@ -130,7 +130,13 @@ const RedemptionsTable = () => { setLoading(true); let sortedRedemptions = [...redemptions]; sortedRedemptions.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedRedemptions[0].id === redemptions[0].id) { sortedRedemptions.reverse(); diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index a3bb6f91..db4745e4 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -228,7 +228,13 @@ const TokensTable = () => { setLoading(true); let sortedTokens = [...tokens]; sortedTokens.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedTokens[0].id === tokens[0].id) { sortedTokens.reverse(); diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index f8fb0a75..ad4e9b49 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -133,7 +133,13 @@ const UsersTable = () => { setLoading(true); let sortedUsers = [...users]; sortedUsers.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedUsers[0].id === users[0].id) { sortedUsers.reverse();