From ac7c0f3a76c50632f9fcaaac642cfa3465e5074d Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 26 Aug 2023 12:05:18 +0800 Subject: [PATCH 01/21] fix: disable channel when 401 received (close #467) --- controller/channel-test.go | 2 +- controller/relay-utils.go | 6 +++++- controller/relay.go | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 8465d51d..4acb2e3b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -174,7 +174,7 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr) { + if shouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 5b3e0274..aaf579ab 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "net/http" "one-api/common" ) @@ -95,13 +96,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus } } -func shouldDisableChannel(err *OpenAIError) bool { +func shouldDisableChannel(err *OpenAIError, statusCode int) bool { if !common.AutomaticDisableChannelEnabled { return false } if err == nil { return false } + if statusCode == http.StatusUnauthorized { + return true + } if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } diff --git a/controller/relay.go b/controller/relay.go index 86f16c45..1eaa2c26 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -185,7 +185,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError) { + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) From a3e267df7eb1fc88a5b8332c8338965ba1782042 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 26 Aug 2023 12:37:45 +0800 Subject: [PATCH 02/21] fix: fix error response (close #468) --- controller/relay-text.go | 3 +-- controller/relay-utils.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 0bad948f..5298d292 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -317,8 +317,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - return errorWrapper( - fmt.Errorf("bad status code: %d", resp.StatusCode), "bad_status_code", resp.StatusCode) + return relayErrorHandler(resp) } } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index aaf579ab..1a9ee0d1 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,11 +1,14 @@ package controller import ( + "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "io" "net/http" "one-api/common" + "strconv" ) var 
stopFinishReason = "stop" @@ -119,3 +122,30 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") } + +func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "one_api_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var textResponse TextResponse + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return + } + openAIErrorWithStatusCode.OpenAIError = textResponse.Error + return +} From fdb2cccf65f3d27f56a05b91c48aed393c7ed9c9 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 26 Aug 2023 13:02:02 +0800 Subject: [PATCH 03/21] perf: initialize all token encoder when starting (close #459, close $460) --- controller/relay-utils.go | 18 ++++++++++++++++++ main.go | 1 + 2 files changed, 19 insertions(+) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 1a9ee0d1..9010d275 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -15,6 +15,24 @@ var stopFinishReason = "stop" var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} +func InitTokenEncoders() { + common.SysLog("initializing token encoders") + fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) + } + for model, _ := range common.ModelRatio { + tokenEncoder, err := tiktoken.EncodingForModel(model) + if err != nil { + common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) + tokenEncoderMap[model] = fallbackTokenEncoder + continue + } + tokenEncoderMap[model] = tokenEncoder + } + common.SysLog("token encoders initialized") +} + func getTokenEncoder(model string) *tiktoken.Tiktoken { if tokenEncoder, ok := tokenEncoderMap[model]; ok { return tokenEncoder diff --git a/main.go b/main.go index f4d20373..9fb0a73e 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,7 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + controller.InitTokenEncoders() // Initialize HTTP server server := gin.Default() From 4f2f911e4d2cd12b1de037d8bef803b1cfd89f61 Mon Sep 17 00:00:00 2001 From: shao0222 <22172112+shao0222@users.noreply.github.com> Date: Sat, 26 Aug 2023 13:10:18 +0800 Subject: [PATCH 04/21] fix: fix the issue of function_call not working when using model mapping (#462) --- controller/relay.go | 1 + 1 file changed, 1 insertion(+) diff --git a/controller/relay.go b/controller/relay.go index 1eaa2c26..6a2d58eb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -40,6 +40,7 @@ type GeneralOpenAIRequest struct { Input any `json:"input,omitempty"` Instruction string `json:"instruction,omitempty"` Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` } type ChatRequest struct { From 5ee24e8acfbb6f9c57bb6ecedaad995ac45201ef Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 26 Aug 2023 13:30:21 +0800 Subject: [PATCH 05/21] feat: support 360's models (close #331, close #461) feat: support 360's models (close #331, close #461) --- README.md | 1 + common/constants.go | 2 + common/model-ratio.go | 85 
++++++++++++++------------ controller/channel-test.go | 4 ++ controller/model.go | 45 ++++++++++++++ web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 3 + 7 files changed, 101 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 36841c79..45c8b603 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + + [x] [360 智脑](https://ai.360.cn) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) + [x] [API2D](https://api2d.com/r/197971) diff --git a/common/constants.go b/common/constants.go index 4b9df311..0ae8ae83 100644 --- a/common/constants.go +++ b/common/constants.go @@ -173,6 +173,7 @@ const ( ChannelTypeZhipu = 16 ChannelTypeAli = 17 ChannelTypeXunfei = 18 + ChannelType360 = 19 ) var ChannelBaseURLs = []string{ @@ -195,4 +196,5 @@ var ChannelBaseURLs = []string{ "https://open.bigmodel.cn", // 16 "https://dashscope.aliyuncs.com", // 17 "", // 18 + "https://ai.360.cn", // 19 } diff --git a/common/model-ratio.go b/common/model-ratio.go index e658cdc1..3f4f64b7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -13,46 +13,51 @@ import ( // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-0301": 0.75, - "gpt-3.5-turbo-0613": 0.75, - "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens - "gpt-3.5-turbo-16k-0613": 1.5, - "text-ada-001": 0.2, - "text-babbage-001": 0.25, - "text-curie-001": 1, - "text-davinci-002": 10, - "text-davinci-003": 10, - "text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 10, - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-ada-002": 0.05, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "dall-e": 8, - "claude-instant-1": 0.815, // $1.63 / 1M tokens - "claude-2": 5.51, // $11.02 / 1M tokens - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "PaLM-2": 1, - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag - "qwen-plus-v1": 0.5715, // Same as above - "SparkDesk": 0.8572, // TBD + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 10, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e": 8, + "claude-instant-1": 0.815, // $1.63 / 1M tokens + "claude-2": 5.51, // 
$11.02 / 1M tokens + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "PaLM-2": 1, + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag + "qwen-plus-v1": 0.5715, // Same as above + "SparkDesk": 0.8572, // TBD + "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens + "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens } func ModelRatio2JSONString() string { diff --git a/controller/channel-test.go b/controller/channel-test.go index 4acb2e3b..686521ef 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -24,6 +24,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr fallthrough case common.ChannelTypeZhipu: fallthrough + case common.ChannelTypeAli: + fallthrough + case common.ChannelType360: + fallthrough case common.ChannelTypeXunfei: return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: diff --git a/controller/model.go b/controller/model.go index c68aa50c..a8ac6a65 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,51 @@ func init() { Root: "SparkDesk", Parent: nil, }, + { + Id: "360GPT_S2_V9", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9", + Parent: nil, + }, + { + Id: "embedding-bert-512-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding-bert-512-v1", + Parent: nil, + }, + { + Id: "embedding_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding_s1_v1", + Parent: nil, + }, + { + Id: "semantic_similarity_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "semantic_similarity_s1_v1", + Parent: nil, + }, + { + Id: "360GPT_S2_V9.4", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9.4", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a17ef374..a14c4e0f 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -7,6 +7,7 @@ export const CHANNEL_OPTIONS = [ { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index fcbdb980..5d8951a1 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -61,6 +61,9 @@ const EditChannel = () => { case 18: localModels = ['SparkDesk']; break; + case 19: + localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', 
'360GPT_S2_V9.4']
+        break;
     }
     setInputs((inputs) => ({ ...inputs, models: localModels }));
   }
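A note on how patch 05's numbers are derived: the header comment in `common/model-ratio.go` pegs a ratio of 1 at $0.002 (or ¥0.014) per 1K tokens. The sketch below is illustrative only, not part of any patch; it reproduces that arithmetic for the newly added `360GPT_S2_V9` ratio and recovers the ¥0.012 / 1K tokens price quoted in its comment:

```go
package main

import "fmt"

// Pricing convention taken from the header comment in
// common/model-ratio.go: a ratio of 1 means $0.002 (or ¥0.014)
// per 1K tokens.
func main() {
	const usdPerUnitRatio = 0.002 // $ per 1K tokens at ratio 1
	const cnyPerUnitRatio = 0.014 // ¥ per 1K tokens at ratio 1

	ratio := 0.8572 // "360GPT_S2_V9", added by this patch

	fmt.Printf("$%.4f / 1K tokens\n", ratio*usdPerUnitRatio) // $0.0017
	fmt.Printf("¥%.4f / 1K tokens\n", ratio*cnyPerUnitRatio) // ¥0.0120
}
```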
From d09d317459e8749b28622f7c3bd710dc9821819f Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 27 Aug 2023 15:28:23 +0800
Subject: [PATCH 06/21] feat: support whisper now (close #197)
---
 common/model-ratio.go     |   2 +-
 controller/model.go       |   9 +++
 controller/relay-audio.go | 147 ++++++++++++++++++++++++++++++++++++++
 controller/relay.go       |   9 +++
 middleware/distributor.go |  10 ++-
 router/relay-router.go    |   4 +-
 6 files changed, 177 insertions(+), 4 deletions(-)
 create mode 100644 controller/relay-audio.go

diff --git a/common/model-ratio.go b/common/model-ratio.go
index 3f4f64b7..70758805 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -31,7 +31,7 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":          10,
 	"text-davinci-edit-001":     10,
 	"code-davinci-edit-001":     10,
-	"whisper-1":                 10,
+	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 	"davinci":                   10,
 	"curie":                     10,
 	"babbage":                   10,
diff --git a/controller/model.go b/controller/model.go
index a8ac6a65..88f95f7b 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -63,6 +63,15 @@
 			Root:       "dall-e",
 			Parent:     nil,
 		},
+		{
+			Id:         "whisper-1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "whisper-1",
+			Parent:     nil,
+		},
 		{
 			Id:         "gpt-3.5-turbo",
 			Object:     "model",
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
new file mode 100644
index 00000000..277ab404
--- /dev/null
+++ b/controller/relay-audio.go
@@ -0,0 +1,147 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+	audioModel := "whisper-1"
+
+	tokenId := c.GetInt("token_id")
+	channelType := c.GetInt("channel")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+
+	preConsumedTokens := common.PreConsumedQuota
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	if err != nil {
+		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+	}
+	if preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		if err != nil {
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping != "" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[audioModel] != "" {
+			audioModel = modelMap[audioModel]
+		}
+	}
+
+	baseURL := common.ChannelBaseURLs[channelType]
+	requestURL := c.Request.URL.String()
+
+	if c.GetString("base_url") != "" {
+		baseURL = c.GetString("base_url")
+	}
+
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	requestBody := c.Request.Body
+
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	err = req.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	var audioResponse AudioResponse
+
+	defer func() {
+		go func() {
+			quota := countTokenText(audioResponse.Text, audioModel)
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}()
+	}()
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &audioResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index 6a2d58eb..056d42d3 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -24,6 +24,7 @@ const (
 	RelayModeModerations
 	RelayModeImagesGenerations
 	RelayModeEdits
+	RelayModeAudio
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -63,6 +64,10 @@ type ImageRequest struct {
 	Size string `json:"size"`
 }
 
+type AudioResponse struct {
+	Text string `json:"text,omitempty"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -159,11 +164,15 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 		relayMode = RelayModeEdits
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+		relayMode = RelayModeAudio
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
 	case RelayModeImagesGenerations:
 		err = relayImageHelper(c, relayMode)
+	case RelayModeAudio:
+		err = relayAudioHelper(c, relayMode)
 	default:
 		err = relayTextHelper(c, relayMode)
 	}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index ebbde535..93827c95 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -58,7 +58,10 @@ func Distribute() func(c *gin.Context) {
 		} else {
 			// Select a channel for the user
 			var modelRequest ModelRequest
-			err := common.UnmarshalBodyReusable(c, &modelRequest)
+			var err error
+			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				err = common.UnmarshalBodyReusable(c, &modelRequest)
+			}
 			if err != nil {
 				c.JSON(http.StatusBadRequest, gin.H{
 					"error": gin.H{
@@ -84,6 +87,11 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "dall-e"
 				}
 			}
+			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = "whisper-1"
+				}
+			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
diff --git a/router/relay-router.go b/router/relay-router.go
index c3c84d8b..a76e42cf 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 		relayV1Router.POST("/embeddings", controller.Relay)
 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
-		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
+		relayV1Router.POST("/audio/transcriptions", controller.Relay)
+		relayV1Router.POST("/audio/translations", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
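Patch 06's `relayAudioHelper` follows the same billing pattern as the text relay: it reserves a flat `common.PreConsumedQuota` up front (skipped when the user's balance exceeds 100× the reservation), then, in a deferred goroutine, settles against the transcript's actual token count. Reduced to its core, the settle step is just a signed delta — a sketch with a hypothetical helper name, since the patch computes this inline:

```go
// Illustrative sketch of relayAudioHelper's settle step; settleDelta
// is a hypothetical name, the patch does this inline via
// model.PostConsumeTokenQuota.
func settleDelta(transcriptTokens int, preConsumedQuota int) int {
	// A positive delta charges the remainder of the bill; a negative
	// delta refunds part of the up-front reservation.
	return transcriptTokens - preConsumedQuota
}
```

The reservation exists so that a long recording cannot be transcribed for free by a token whose balance would only cover a fraction of it.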
From 56b50073798e3930cbccceaf17b6e2a6a68524b8 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 27 Aug 2023 16:16:45 +0800
Subject: [PATCH 07/21] feat: support OpenRouter now (close #333, close #340)
---
 common/constants.go                    | 42 ++++++++++++++------------
 controller/relay-text.go               |  4 +++
 web/src/constants/channel.constants.js |  1 +
 web/src/pages/Channel/EditChannel.js   |  6 ++--
 4 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/common/constants.go b/common/constants.go
index 0ae8ae83..e5211e3d 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -154,26 +154,27 @@ const (
 )
 
 const (
-	ChannelTypeUnknown   = 0
-	ChannelTypeOpenAI    = 1
-	ChannelTypeAPI2D     = 2
-	ChannelTypeAzure     = 3
-	ChannelTypeCloseAI   = 4
-	ChannelTypeOpenAISB  = 5
-	ChannelTypeOpenAIMax = 6
-	ChannelTypeOhMyGPT   = 7
-	ChannelTypeCustom    = 8
-	ChannelTypeAILS      = 9
-	ChannelTypeAIProxy   = 10
-	ChannelTypePaLM      = 11
-	ChannelTypeAPI2GPT   = 12
-	ChannelTypeAIGC2D    = 13
-	ChannelTypeAnthropic = 14
-	ChannelTypeBaidu     = 15
-	ChannelTypeZhipu     = 16
-	ChannelTypeAli       = 17
-	ChannelTypeXunfei    = 18
-	ChannelType360       = 19
+	ChannelTypeUnknown    = 0
+	ChannelTypeOpenAI     = 1
+	ChannelTypeAPI2D      = 2
+	ChannelTypeAzure      = 3
+	ChannelTypeCloseAI    = 4
+	ChannelTypeOpenAISB   = 5
+	ChannelTypeOpenAIMax  = 6
+	ChannelTypeOhMyGPT    = 7
+	ChannelTypeCustom     = 8
+	ChannelTypeAILS       = 9
+	ChannelTypeAIProxy    = 10
+	ChannelTypePaLM       = 11
+	ChannelTypeAPI2GPT    = 12
+	ChannelTypeAIGC2D     = 13
+	ChannelTypeAnthropic  = 14
+	ChannelTypeBaidu      = 15
+	ChannelTypeZhipu      = 16
+	ChannelTypeAli        = 17
+	ChannelTypeXunfei     = 18
+	ChannelType360        = 19
+	ChannelTypeOpenRouter = 20
 )
 
 var ChannelBaseURLs = []string{
@@ -197,4 +198,5 @@ var ChannelBaseURLs = []string{
 	"https://dashscope.aliyuncs.com", // 17
 	"",                               // 18
 	"https://ai.360.cn",              // 19
+	"https://openrouter.ai/api",      // 20
 }
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 5298d292..624b9d01 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -282,6 +282,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			req.Header.Set("api-key", apiKey)
 		} else {
 			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+			if channelType == common.ChannelTypeOpenRouter {
+				req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+				req.Header.Set("X-Title", "One API")
+			}
 		}
 	case APITypeClaude:
 		req.Header.Set("x-api-key", apiKey)
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index a14c4e0f..b1631479 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
+  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 5d8951a1..da11b588 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -1,6 +1,6 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
-import { useParams, useNavigate } from 'react-router-dom';
+import { useNavigate, useParams } from 'react-router-dom';
 import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';
@@ -19,7 +19,7 @@ const EditChannel = () => {
   const handleCancel = () => {
     navigate('/channel');
   };
-
+
   const originInputs = {
     name: '',
     type: 1,
@@ -62,7 +62,7 @@ const EditChannel = () => {
         localModels = ['SparkDesk'];
         break;
       case 19:
-        localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']
+        localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
         break;
     }
     setInputs((inputs) => ({ ...inputs, models: localModels }));
   }
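Patch 07 distinguishes OpenRouter from a plain OpenAI-compatible proxy only by two extra attribution headers. Outside the relay, an equivalent request could be built as below — a sketch assuming OpenRouter's OpenAI-compatible `/v1/chat/completions` path, which the patch itself never spells out:

```go
package example

import (
	"bytes"
	"net/http"
)

// Sketch: a standalone request carrying the attribution headers the
// relay sets for channel type 20. The URL path is an assumption
// (OpenRouter's OpenAI-compatible endpoint); apiKey and payload are
// caller-supplied.
func newOpenRouterRequest(apiKey string, payload []byte) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost,
		"https://openrouter.ai/api/v1/chat/completions",
		bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	// The two headers patch 07 adds for OpenRouter attribution:
	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
	req.Header.Set("X-Title", "One API")
	return req, nil
}
```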
From ef2c5abb5b3d1fc5bffdb358758156e790db38f8 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Wed, 30 Aug 2023 20:51:37 +0800
Subject: [PATCH 08/21] docs: update README
---
 README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 45c8b603..a0f3bcb9 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,8 @@
 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 
+如果启动失败,请添加 `--privileged=true`,具体参考 #482。
+
 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 
 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
@@ -275,8 +277,9 @@ graph LR
 不加的话将会使用负载均衡的方式使用多个渠道。
 
 ### 环境变量
-1. 
`REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 + 例子:`SESSION_SECRET=random_string` 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 @@ -339,6 +342,7 @@ https://openai.justsong.cn 5. ChatGPT Next Web 报错:`Failed to fetch` + 部署的时候不要设置 `BASE_URL`。 + 检查你的接口地址和 API Key 有没有填对。 + + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 报错:`当前分组负载已饱和,请稍后再试` + 上游通道 429 了。 From abbf2fded0a694390e0f025b63e757d23b86155c Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 30 Aug 2023 21:15:56 +0800 Subject: [PATCH 09/21] perf: preallocate array capacity --- controller/channel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/channel.go b/controller/channel.go index 8afc0eed..50b2b5f6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") - channels := make([]model.Channel, 0) + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue From f0d5e102a3dc22289e00860f63ebd0f125531641 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 30 Aug 2023 21:43:01 +0800 Subject: [PATCH 10/21] fix: fix log table use created_at as key instead of id Co-authored-by: 13714733197 <13714733197@163.com> --- web/src/components/LogsTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index bacb7689..c981e261 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -324,7 +324,7 @@ const LogsTable = () => { .map((log, idx) => { if (log.deleted) return <>; return ( - + {renderTimestamp(log.created_at)} { isAdminUser && ( From 04acdb1ccb059d7dd86f4f397f4a027259edc74d Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 12:51:59 +0800 Subject: [PATCH 11/21] feat: support aiproxy's library --- common/constants.go | 44 ++--- controller/relay-aiproxy.go | 220 +++++++++++++++++++++++++ controller/relay-text.go | 35 ++++ middleware/distributor.go | 7 +- web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 14 ++ 6 files changed, 299 insertions(+), 22 deletions(-) create mode 100644 controller/relay-aiproxy.go diff --git a/common/constants.go b/common/constants.go index e5211e3d..66ca06f4 100644 --- a/common/constants.go +++ b/common/constants.go @@ -154,27 +154,28 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 
16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + ChannelTypeAIProxyLibrary = 21 ) var ChannelBaseURLs = []string{ @@ -199,4 +200,5 @@ var ChannelBaseURLs = []string{ "", // 18 "https://ai.360.cn", // 19 "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 } diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go new file mode 100644 index 00000000..d0159ce8 --- /dev/null +++ b/controller/relay-aiproxy.go @@ -0,0 +1,220 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strconv" + "strings" +) + +// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 + +type AIProxyLibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type AIProxyLibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type AIProxyLibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type AIProxyLibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []AIProxyLibraryDocument `json:"documents"` + AIProxyLibraryError +} + +type AIProxyLibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []AIProxyLibraryDocument `json:"documents"` +} + +func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { + query := "" + if len(request.Messages) != 0 { + query = request.Messages[len(request.Messages)-1].Content + } + return &AIProxyLibraryRequest{ + Model: request.Model, + Stream: request.Stream, + Query: query, + } +} + +func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { + if len(documents) == 0 { + return "" + } + content := "\n\n参考文档:\n" + for i, document := range documents { + content += fmt.Sprintf("%d. 
[%s](%s)\n", i+1, document.Title, document.URL) + } + return content +} + +func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { + content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: common.GetUUID(), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + } + return &fullTextResponse +} + +func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = aiProxyDocuments2Markdown(documents) + choice.FinishReason = &stopFinishReason + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = response.Content + return &ChatCompletionsStreamResponse{ + Id: common.GetUUID(), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: response.Model, + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } +} + +func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + var documents []AIProxyLibraryDocument + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var AIProxyLibraryResponse AIProxyLibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + response := documentsAIProxyLibrary(documents) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", 
http.StatusInternalServerError), nil + } + return nil, &usage +} + +func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var AIProxyLibraryResponse AIProxyLibraryResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if AIProxyLibraryResponse.ErrCode != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: AIProxyLibraryResponse.Message, + Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), + Code: AIProxyLibraryResponse.ErrCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 624b9d01..6f410f96 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -22,6 +22,7 @@ const ( APITypeZhipu APITypeAli APITypeXunfei + APITypeAIProxyLibrary ) var httpClient *http.Client @@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAli case common.ChannelTypeXunfei: apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -171,6 +174,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + case APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int var completionTokens int @@ -263,6 +268,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeAIProxyLibrary: + aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } var req *http.Request @@ -302,6 +315,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + default: + req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -516,6 +531,26 @@ func relayTextHelper(c 
*gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } else { return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) } + case APITypeAIProxyLibrary: + if isStream { + err, usage := aiProxyLibraryStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + err, usage := aiProxyLibraryHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 93827c95..e8b76596 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { + switch channel.Type { + case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + case common.ChannelTypeAIProxyLibrary: + c.Set("library_id", channel.Other) } c.Next() } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index b1631479..1eecfc5a 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index da11b588..7e150b4a 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -295,6 +295,20 @@ const EditChannel = () => { ) } + { + inputs.type === 21 && ( + + + + ) + } Date: Sun, 3 Sep 2023 14:58:20 +0800 Subject: [PATCH 12/21] feat: add batch update support (close #414) --- README.md | 4 +++ common/constants.go | 3 ++ main.go | 5 +++ model/channel.go | 8 +++++ model/token.go | 16 ++++++++++ model/user.go | 26 +++++++++++++++- model/utils.go | 75 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 model/utils.go diff --git a/README.md b/README.md index a0f3bcb9..a2105df2 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,10 @@ graph LR + 例子:`CHANNEL_TEST_FREQUENCY=1440` 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` +10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` +11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` ### 命令行参数 1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 66ca06f4..b272fbe6 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var BatchUpdateEnabled = false +var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/main.go b/main.go index 9fb0a73e..8c5f2f31 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + common.BatchUpdateEnabled = true + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } controller.InitTokenEncoders() // Initialize HTTP server diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..5c495bab 100644 --- a/model/channel.go +++ b/model/channel.go @@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/token.go b/model/token.go index 7cd226c6..dfda27e3 100644 --- a/model/token.go +++ b/model/token.go @@ -131,6 +131,14 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -144,6 +152,14 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), diff --git a/model/user.go b/model/user.go index 7c771840..67511267 100644 --- a/model/user.go +++ b/model/user.go @@ -275,6 +275,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } @@ -283,6 +291,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int) (err error) { err = 
DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,10 +309,18 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), - "request_count": gorm.Expr("request_count + ?", 1), + "request_count": gorm.Expr("request_count + ?", count), }, ).Error if err != nil { diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 00000000..61734332 --- /dev/null +++ b/model/utils.go @@ -0,0 +1,75 @@ +package model + +import ( + "one-api/common" + "sync" + "time" +) + +const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeChannelUsedQuota +) + +var batchUpdateStores []map[int]int +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func addNewRecord(type_ int, id int, value int) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + common.SysLog("batch update started") + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int) + batchUpdateLocks[i].Unlock() + + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + common.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + common.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuotaAndRequestCount: + updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + common.SysLog("batch update finished") +} From 9db93316c4bd86c66fb74d5e7c36eaf4f3fde697 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 15:12:54 +0800 Subject: [PATCH 13/21] docs: update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a2105df2..382dfb35 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,7 @@ graph LR + 例子:`POLLING_INTERVAL=5` 10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 11. 
`BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` From 7e575abb951475c95adfe32d82994b81d50a7094 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 15:50:49 +0800 Subject: [PATCH 14/21] feat: add channel type FastGPT --- common/constants.go | 46 ++++++++++++++------------ web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 32 ++++++++++++++++-- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/common/constants.go b/common/constants.go index b272fbe6..4a3f3f2b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -179,29 +179,31 @@ const ( ChannelType360 = 19 ChannelTypeOpenRouter = 20 ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 1eecfc5a..e42afc6e 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7e150b4a..d75e67eb 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k' }; +function type2secretPrompt(type) { + // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? 
'按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') + switch (type) { + case 15: + return "按照如下格式输入:APIKey|SecretKey" + case 18: + return "按照如下格式输入:APPID|APISecret|APIKey" + case 22: + return "按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041" + default: + return "请输入渠道对应的鉴权密钥" + } +} + const EditChannel = () => { const params = useParams(); const navigate = useNavigate(); @@ -389,7 +403,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={type2secretPrompt(inputs.type)} onChange={handleInputChange} value={inputs.key} autoComplete='new-password' @@ -407,7 +421,7 @@ const EditChannel = () => { ) } { - inputs.type !== 3 && inputs.type !== 8 && ( + inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( { ) } + { + inputs.type === 22 && ( + + + + ) + } From 621eb91b46a1909bf439cc50688bcdbc689e28b2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:31:58 +0800 Subject: [PATCH 15/21] chore: pass through error out --- controller/relay-text.go | 1 - middleware/auth.go | 13 +++++++++++- model/cache.go | 29 ++++++++++++++++----------- model/token.go | 43 ++++++++++++++++++++++------------------ model/user.go | 9 ++++----- 5 files changed, 57 insertions(+), 38 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 6f410f96..c6659799 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -377,7 +377,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) } } diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..95516d6e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) { c.Abort() return } - if !model.CacheIsUserEnabled(token.UserId) { + userEnabled, err := model.IsUserEnabled(token.UserId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if !userEnabled { c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "message": "用户已被封禁", diff --git a/model/cache.go b/model/cache.go index 55fbba9b..c28952b5 100644 --- a/model/cache.go +++ b/model/cache.go @@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { return err } -func CacheIsUserEnabled(userId int) bool { +func CacheIsUserEnabled(userId int) (bool, error) { if !common.RedisEnabled { return IsUserEnabled(userId) } enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err != nil { - status := common.UserStatusDisabled - if IsUserEnabled(userId) { - status = common.UserStatusEnabled - } - enabled = fmt.Sprintf("%d", status) - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } + if err == nil { + return enabled == "1", nil } - return enabled == "1" + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + 
enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err } var group2model2channels map[string]map[string][]*Channel diff --git a/model/token.go b/model/token.go index dfda27e3..0fa984d3 100644 --- a/model/token.go +++ b/model/token.go @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err == nil { + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌额度已用尽") } - go func() { - token.AccessedTime = common.GetTimestamp() - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token" + err.Error()) - } - }() return token, nil } return nil, errors.New("无效的令牌") @@ -141,8 +144,9 @@ func IncreaseTokenQuota(id int, quota int) (err error) { func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota + ?", quota), - "used_quota": gorm.Expr("used_quota - ?", quota), + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err @@ -162,8 +166,9 @@ func DecreaseTokenQuota(id int, quota int) (err error) { func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota - ?", quota), - "used_quota": gorm.Expr("used_quota + ?", quota), + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 67511267..cee4b023 100644 --- a/model/user.go +++ b/model/user.go @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) bool { +func IsUserEnabled(userId int) (bool, error) { if userId == 0 { - return false + return false, errors.New("user id is empty") } var user User err := DB.Where("id = ?", userId).Select("status").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) - return false + 
return false, err } - return user.Status == common.UserStatusEnabled + return user.Status == common.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { From 276163affdef85d5adf9044fdbe0ab1b32455fcd Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:40:58 +0800 Subject: [PATCH 16/21] fix: press enter to submit custom model name --- web/src/pages/Channel/EditChannel.js | 50 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index d75e67eb..2ad3dc6a 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -14,13 +14,13 @@ function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { case 15: - return "按照如下格式输入:APIKey|SecretKey" + return '按照如下格式输入:APIKey|SecretKey'; case 18: - return "按照如下格式输入:APPID|APISecret|APIKey" + return '按照如下格式输入:APPID|APISecret|APIKey'; case 22: - return "按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041" + return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; default: - return "请输入渠道对应的鉴权密钥" + return '请输入渠道对应的鉴权密钥'; } } @@ -207,6 +207,24 @@ const EditChannel = () => { } }; + const addCustomModel = () => { + if (customModel.trim() === '') return; + if (inputs.models.includes(customModel)) return; + let localModels = [...inputs.models]; + localModels.push(customModel); + let localModelOptions = []; + localModelOptions.push({ + key: customModel, + text: customModel, + value: customModel + }); + setModelOptions(modelOptions => { + return [...modelOptions, ...localModelOptions]; + }); + setCustomModel(''); + handleInputChange(null, { name: 'models', value: localModels }); + }; + return ( <> @@ -350,29 +368,19 @@ const EditChannel = () => { }}>清除所有模型 { - if (customModel.trim() === '') return; - if (inputs.models.includes(customModel)) return; - let localModels = [...inputs.models]; - localModels.push(customModel); - let localModelOptions = []; - localModelOptions.push({ - key: customModel, - text: customModel, - value: customModel - }); - setModelOptions(modelOptions => { - return [...modelOptions, ...localModelOptions]; - }); - setCustomModel(''); - handleInputChange(null, { name: 'models', value: localModels }); - }}>填入 + } placeholder='输入自定义模型名称' value={customModel} onChange={(e, { value }) => { setCustomModel(value); }} + onKeyDown={(e) => { + if (e.key === 'Enter') { + addCustomModel(); + e.preventDefault(); + } + }} /> From a721a5b6f9d4a9ce721ce323c5eeb0be34a3b97b Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:46:07 +0800 Subject: [PATCH 17/21] chore: add error prompt for Azure --- controller/channel-test.go | 7 ++++++- i18n/en.json | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 686521ef..8c7e6f0d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,7 +14,7 @@ import ( "time" ) -func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { +func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr return 
errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: request.Model = "gpt-35-turbo" + defer func() { + if err != nil { + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") + } + }() default: request.Model = "gpt-3.5-turbo" } diff --git a/i18n/en.json b/i18n/en.json index aed65979..9b2ca4c8 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -523,5 +523,6 @@ "按照如下格式输入:": "Enter in the following format:", "模型版本": "Model version", "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", - "点击查看": "click to view" + "点击查看": "click to view", + "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" } From 0f949c3782f302b41b07475696fdf63d6cdb99ff Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:49:41 +0800 Subject: [PATCH 18/21] docs: update README (close #482) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 382dfb35..92eb7f6a 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 -如果启动失败,请添加 `--privileged=true`,具体参考 #482。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482。 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 From c55bb6781824f384f940e7daf1b94cd2f5fe3ff9 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:50:00 +0800 Subject: [PATCH 19/21] docs: update README (close #482) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 92eb7f6a..4f25d3f5 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 -如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482。 +如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 From bd6fe1e93cf96d79d3a6f7872a6df1c367872ed6 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 21:56:37 +0800 Subject: [PATCH 20/21] feat: able to config rate limit (close #477) --- README.md | 3 +++ common/constants.go | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4f25d3f5..b89c6be8 100644 --- a/README.md +++ b/README.md @@ -311,6 +311,9 @@ graph LR + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` +12. 请求频率限制: + + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 ### 命令行参数 1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 4a3f3f2b..69bd12a8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -114,10 +114,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = 180 + GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = 60 + GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 From d0a0e871e117591fdf018e619c5d71ee371331f6 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 3 Sep 2023 22:12:35 +0800 Subject: [PATCH 21/21] fix: support ali's embedding model (#481, close #469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat:支持阿里的 embedding 模型 * fix: add to model list --------- Co-authored-by: JustSong Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> --- common/model-ratio.go | 7 ++- controller/model.go | 9 +++ controller/relay-ali.go | 88 ++++++++++++++++++++++++++++ controller/relay-baidu.go | 15 +---- controller/relay-text.go | 24 +++++++- controller/relay.go | 19 ++++++ web/src/pages/Channel/EditChannel.js | 2 +- 7 files changed, 144 insertions(+), 20 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 70758805..eeb23e07 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag - "qwen-plus-v1": 0.5715, // Same as above - "SparkDesk": 0.8572, // TBD + "qwen-v1": 0.8572, // ¥0.012 / 1k tokens + "qwen-plus-v1": 1, // ¥0.014 / 1k tokens + "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 88f95f7b..637ebe10 100644 --- a/controller/model.go +++ b/controller/model.go @@ -360,6 +360,15 @@ func init() { Root: "qwen-plus-v1", Parent: nil, }, + { + Id: "text-embedding-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "text-embedding-v1", + Parent: nil, + }, { Id: "SparkDesk", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 9dca9a89..50dc743c 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -35,6 +35,29 @@ type AliChatRequest struct { Parameters AliParameters `json:"parameters,omitempty"` } +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + type AliError struct { Code string `json:"code"` Message string 
`json:"message"` @@ -44,6 +67,7 @@ type AliError struct { type AliUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` } type AliOutput struct { @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { } } +func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { + return &AliEmbeddingRequest{ + Model: "text-embedding-v1", + Input: struct { + Texts []string `json:"texts"` + }{ + Texts: request.ParseInput(), + }, + } +} + +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var aliResponse AliEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&aliResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Code != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { + openAIEmbeddingResponse := OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + } + + for _, item := range response.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { choice := OpenAITextResponseChoice{ Index: 0, diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 39f31a9a..ed08ac04 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom } func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - baiduEmbeddingRequest := BaiduEmbeddingRequest{ - Input: nil, + return &BaiduEmbeddingRequest{ + Input: request.ParseInput(), } - switch request.Input.(type) { - case string: - baiduEmbeddingRequest.Input = []string{request.Input.(string)} - case []any: - for _, item := range request.Input.([]any) { - if str, ok := item.(string); ok { - baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) - } - } - } - return &baiduEmbeddingRequest } func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { diff --git a/controller/relay-text.go b/controller/relay-text.go index c6659799..b190a999 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -174,6 +174,9 @@ func 
relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } case APITypeAIProxyLibrary: fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } @@ -262,8 +265,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) case APITypeAli: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err := json.Marshal(aliRequest) + var jsonStr []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - err, usage := aliHandler(c, resp) + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + err, usage = aliHandler(c, resp) + } if err != nil { return err } diff --git a/controller/relay.go b/controller/relay.go index 056d42d3..d20663f6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct { Functions any `json:"functions,omitempty"` } +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 2ad3dc6a..78ff1952 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -67,7 +67,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; break; case 17: - localModels = ['qwen-v1', 'qwen-plus-v1']; + localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
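
For reference, the `ParseInput` helper introduced in controller/relay.go above normalizes the polymorphic OpenAI `input` field (a bare string or an array of strings) into a `[]string` for the Baidu and Ali embedding conversions. A minimal, self-contained sketch of that behavior follows; the trimmed struct and sample values are illustrative stand-ins, not part of any patch:

package main

import "fmt"

// Trimmed stand-in for controller.GeneralOpenAIRequest, kept only to
// demonstrate ParseInput; the real type carries many more fields.
type GeneralOpenAIRequest struct {
	Input any `json:"input,omitempty"`
}

// ParseInput mirrors the logic added in the patch: a string becomes a
// one-element slice, a JSON array (decoded as []any) keeps only its
// string elements, and anything else yields nil.
func (r GeneralOpenAIRequest) ParseInput() []string {
	if r.Input == nil {
		return nil
	}
	var input []string
	switch v := r.Input.(type) {
	case string:
		input = []string{v}
	case []any:
		input = make([]string, 0, len(v))
		for _, item := range v {
			if str, ok := item.(string); ok {
				input = append(input, str)
			}
		}
	}
	return input
}

func main() {
	fmt.Println(GeneralOpenAIRequest{Input: "hello"}.ParseInput())             // [hello]
	fmt.Println(GeneralOpenAIRequest{Input: []any{"a", "b", 42}}.ParseInput()) // [a b] (non-string items are skipped)
}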