From 2ba28c72cbd41ed1020d71d3fd57367ea99be7fd Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 30 Mar 2024 10:43:26 +0800 Subject: [PATCH 01/35] feat: support function call for ali (close #1242) --- common/conv/any.go | 6 ++++++ relay/channel/ali/main.go | 28 ++++++++++++++++------------ relay/channel/ali/model.go | 18 +++++++++++++----- relay/channel/openai/main.go | 3 ++- relay/channel/openai/model.go | 9 +++------ relay/channel/tencent/main.go | 3 ++- relay/model/general.go | 26 ++++++++++++++------------ relay/model/message.go | 7 ++++--- relay/model/tool.go | 14 ++++++++++++++ 9 files changed, 74 insertions(+), 40 deletions(-) create mode 100644 common/conv/any.go create mode 100644 relay/model/tool.go diff --git a/common/conv/any.go b/common/conv/any.go new file mode 100644 index 00000000..467e8bb7 --- /dev/null +++ b/common/conv/any.go @@ -0,0 +1,6 @@ +package conv + +func AsString(v any) string { + str, _ := v.(string) + return str +} diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index 62115d58..6fdfa4d4 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -48,7 +48,10 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { MaxTokens: request.MaxTokens, Temperature: request.Temperature, TopP: request.TopP, + TopK: request.TopK, + ResultFormat: "message", }, + Tools: request.Tools, } } @@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR } func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: response.Output.Text, - }, - FinishReason: response.Output.FinishReason, - } fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, + Choices: response.Output.Choices, Usage: model.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, @@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { } func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if len(aliResponse.Output.Choices) == 0 { + return nil + } + aliChoice := aliResponse.Output.Choices[0] var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text - if aliResponse.Output.FinishReason != "null" { - finishReason := aliResponse.Output.FinishReason + choice.Delta = aliChoice.Message + if aliChoice.FinishReason != "null" { + finishReason := aliChoice.FinishReason choice.FinishReason = &finishReason } response := openai.ChatCompletionsStreamResponse{ @@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + return true + } //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) @@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := c.Request.Context() var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + logger.Debugf(ctx, "response body: %s\n", responseBody) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index 76e814d1..e19d427a 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -1,5 +1,10 @@ package ali +import ( + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" +) + type Message struct { Content string `json:"content"` Role string `json:"role"` @@ -18,12 +23,14 @@ type Parameters struct { IncrementalOutput bool `json:"incremental_output,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` } type ChatRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Parameters `json:"parameters,omitempty"` + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` } type EmbeddingRequest struct { @@ -62,8 +69,9 @@ type Usage struct { } type Output struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` } type ChatResponse struct { diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index d47cd164..63cb9ae8 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E continue // just ignore the error } for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content + responseText += conv.AsString(choice.Delta.Content) } if streamResponse.Usage != nil { usage = streamResponse.Usage diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 6c0b2c53..30d77739 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -118,12 +118,9 @@ type ImageResponse struct { } type ChatCompletionsStreamResponseChoice struct { - Index int `json:"index"` - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Delta model.Message `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index cfdc0bfd..b5a64cde 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" @@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += conv.AsString(response.Choices[0].Delta.Content) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/model/general.go b/relay/model/general.go index fbcc04e8..86facf04 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -5,25 +5,27 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` User string `json:"user,omitempty"` + Prompt any `json:"prompt,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/message.go b/relay/model/message.go index c6c8a271..32a1055b 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,9 +1,10 @@ package model type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go new file mode 100644 index 00000000..253dca35 --- /dev/null +++ b/relay/model/tool.go @@ -0,0 +1,14 @@ +package model + +type Tool struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` // request + Arguments any `json:"arguments,omitempty"` // response +} From 3f3c13c98c3ba10d5ca674c6c688f75fe9148c3e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 30 Mar 2024 10:43:26 +0800 Subject: [PATCH 02/35] feat: support top_k for claude (close #1239) --- relay/channel/anthropic/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index 3eeb0b2c..04e65d99 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { MaxTokens: textRequest.MaxTokens, Temperature: textRequest.Temperature, TopP: textRequest.TopP, + TopK: textRequest.TopK, Stream: textRequest.Stream, } if claudeRequest.MaxTokens == 0 { From a9c464ec5ad6861d10cbd0f4bcd9859f7fd8abdd Mon Sep 17 00:00:00 2001 From: ManJieqi <40858189+ManJieqi@users.noreply.github.com> Date: Sat, 30 Mar 2024 11:06:31 +0800 Subject: [PATCH 03/35] =?UTF-8?q?fix:=20update=20model-ratio.go=20?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E6=96=87=E5=BF=83=E8=AE=A1=E8=B4=B9=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 统一文心计费模型名称 --- common/model-ratio.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 460d4843..aa75042e 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -75,7 +75,7 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "bge-large-zh": 0.002 * RMB, "bge-large-en": 0.002 * RMB, From 06a3fc54216cf1a8229193f749a9d316db5bfb9e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 31 Mar 2024 22:23:42 +0800 Subject: [PATCH 04/35] chore: update GeneralOpenAIRequest --- relay/model/general.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relay/model/general.go b/relay/model/general.go index 86facf04..30772894 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -24,6 +24,8 @@ type GeneralOpenAIRequest struct { User string `json:"user,omitempty"` Prompt any `json:"prompt,omitempty"` Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` Instruction string `json:"instruction,omitempty"` Size string `json:"size,omitempty"` } From f89ae5ad5830ab5962b944b5becedecd60de6e3c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 31 Mar 2024 23:12:29 +0800 Subject: [PATCH 05/35] feat: initial function call support for xunfei --- relay/channel/xunfei/main.go | 34 ++++++++++++++++++++++++++++++++-- relay/channel/xunfei/model.go | 11 ++++++++--- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 5e7014cb..67784a56 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -26,7 +26,11 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) + var lastToolCalls []model.Tool for _, message := range request.Messages { + if message.ToolCalls != nil { + lastToolCalls = message.ToolCalls + } messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), @@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages + if len(lastToolCalls) != 0 { + for _, toolCall := range lastToolCalls { + xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) + } + } + return &xunfeiRequest } +func getToolCalls(response *ChatResponse) []model.Tool { + var toolCalls []model.Tool + if len(response.Payload.Choices.Text) == 0 { + return toolCalls + } + item := response.Payload.Choices.Text[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", helper.GetUUID()), + Type: "function", + Function: *item.FunctionCall, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []ChatResponseTextItem{ @@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + ToolCalls: getToolCalls(response), }, FinishReason: constant.StopFinishReason, } @@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &constant.StopFinishReason } diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go index 1266739d..e9cc59a6 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/channel/xunfei/model.go @@ -26,13 +26,18 @@ type ChatRequest struct { Message struct { Text []Message `json:"text"` } `json:"message"` + Functions struct { + Text []model.Function `json:"text,omitempty"` + } `json:"functions"` } `json:"payload"` } type ChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + ContentType string `json:"content_type"` + FunctionCall *model.Function `json:"function_call"` } type ChatResponse struct { From e3cfb1fa524107439e4b0caec0137e249bd06467 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 31 Mar 2024 23:41:52 +0800 Subject: [PATCH 06/35] feat: use given usage if available in stream mode --- relay/channel/openai/adaptor.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 1f153c3e..9be695f2 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string - err, responseText, _ = StreamHandler(c, resp, meta.Mode) - usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + err, responseText, usage = StreamHandler(c, resp, meta.Mode) + if usage == nil { + usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } } else { err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } From 065da8ef8c8bcbc0a7fc3ae22e39397ffe036b6a Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 00:46:30 +0800 Subject: [PATCH 07/35] fix: fix ali function call (#1242) --- relay/channel/ali/main.go | 2 +- relay/channel/ali/model.go | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index 6fdfa4d4..dd1707ee 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -50,8 +50,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, TopK: request.TopK, ResultFormat: "message", + Tools: request.Tools, }, - Tools: request.Tools, } } diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index e19d427a..3b8a8372 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -16,21 +16,21 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - ResultFormat string `json:"result_format,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` } type ChatRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Parameters `json:"parameters,omitempty"` - Tools []model.Tool `json:"tools,omitempty"` + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` } type EmbeddingRequest struct { From dc7aaf2de5aaf0000073a7466acbda6fe213c291 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 02:08:18 +0800 Subject: [PATCH 08/35] feat: able to set model limitation for token (close #178) --- controller/model.go | 28 ++++++++++++ controller/token.go | 1 + middleware/auth.go | 13 ++++++ middleware/distributor.go | 38 +++------------- middleware/utils.go | 42 ++++++++++++++++++ model/cache.go | 20 +++++++++ model/channel.go | 30 ++++++++++++- model/token.go | 29 ++++++------ router/api-router.go | 1 + web/default/src/pages/Token/EditToken.js | 56 +++++++++++++++++++++--- 10 files changed, 204 insertions(+), 54 deletions(-) diff --git a/controller/model.go b/controller/model.go index 4c5476b4..bf4b83a7 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -142,3 +143,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt("id") + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/token.go b/controller/token.go index 949931da..c6128534 100644 --- a/controller/token.go +++ b/controller/token.go @@ -216,6 +216,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index 30997efd..443199d0 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" @@ -107,6 +108,18 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } + requestModel, err := getRequestModel(c) + if err != nil { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set("request_model", requestModel) + if token.Models != nil && *token.Models != "" { + if !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) diff --git a/middleware/distributor.go b/middleware/distributor.go index e845c2f8..04489a2b 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,12 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "net/http" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) { return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel := c.GetString("request_model") + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..b65b018b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,9 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/cache.go b/model/cache.go index 244fe6ac..cfc5445a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,6 +21,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex diff --git a/model/channel.go b/model/channel.go index fc4905b1..24829bc5 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" @@ -8,6 +9,8 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" + "sort" + "strings" ) type Channel struct { @@ -25,7 +28,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` @@ -202,3 +205,28 @@ func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + var modelsList []string + err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error + if err != nil { + return nil, err + } + set := make(map[string]bool) + for i := 0; i < len(modelsList); i++ { + modelList := strings.Split(modelsList[i], ",") + for _, model := range modelList { + set[model] = true + } + } + modelList := make([]string, 0, len(set)) + for model := range set { + modelList = append(modelList, model) + } + sort.Strings(modelList) + return modelList, err +} diff --git a/model/token.go b/model/token.go index 493e27c9..fef80fcf 100644 --- a/model/token.go +++ b/model/token.go @@ -12,24 +12,25 @@ import ( ) type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` } func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error query := DB.Where("user_id = ?", userId) - + switch order { case "remain_quota": query = query.Order("unlimited_quota desc, remain_quota desc") @@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token default: query = query.Order("id desc") } - + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -121,7 +122,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error return err } diff --git a/router/api-router.go b/router/api-router.go index 5b755ede..4aa6d830 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -43,6 +43,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") diff --git a/web/default/src/pages/Token/EditToken.js b/web/default/src/pages/Token/EditToken.js index 0ab37c29..6bc3ad23 100644 --- a/web/default/src/pages/Token/EditToken.js +++ b/web/default/src/pages/Token/EditToken.js @@ -1,19 +1,21 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; -import { useParams, useNavigate } from 'react-router-dom'; -import { API, showError, showSuccess, timestamp2string } from '../../helpers'; -import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; +import { useNavigate, useParams } from 'react-router-dom'; +import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers'; +import { renderQuotaWithPrompt } from '../../helpers/render'; const EditToken = () => { const params = useParams(); const tokenId = params.id; const isEdit = tokenId !== undefined; const [loading, setLoading] = useState(isEdit); + const [modelOptions, setModelOptions] = useState([]); const originInputs = { name: '', remain_quota: isEdit ? 0 : 500000, expired_time: -1, - unlimited_quota: false + unlimited_quota: false, + models: [] }; const [inputs, setInputs] = useState(originInputs); const { name, remain_quota, expired_time, unlimited_quota } = inputs; @@ -22,8 +24,8 @@ const EditToken = () => { setInputs((inputs) => ({ ...inputs, [name]: value })); }; const handleCancel = () => { - navigate("/token"); - } + navigate('/token'); + }; const setExpiredTime = (month, day, hour, minute) => { let now = new Date(); let timestamp = now.getTime() / 1000; @@ -50,6 +52,11 @@ const EditToken = () => { if (data.expired_time !== -1) { data.expired_time = timestamp2string(data.expired_time); } + if (data.models === '') { + data.models = []; + } else { + data.models = data.models.split(','); + } setInputs(data); } else { showError(message); @@ -60,8 +67,26 @@ const EditToken = () => { if (isEdit) { loadToken().then(); } + loadAvailableModels().then(); }, []); + const loadAvailableModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + let options = data.map((model) => { + return { + key: model, + text: model, + value: model + }; + }); + setModelOptions(options); + } else { + showError(message); + } + }; + const submit = async () => { if (!isEdit && inputs.name === '') return; let localInputs = inputs; @@ -74,6 +99,7 @@ const EditToken = () => { } localInputs.expired_time = Math.ceil(time / 1000); } + localInputs.models = localInputs.models.join(','); let res; if (isEdit) { res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); @@ -109,6 +135,24 @@ const EditToken = () => { required={!isEdit} /> + + { + copy(value).then(); + }} + selection + onChange={handleInputChange} + value={inputs.models} + autoComplete='new-password' + options={modelOptions} + /> + Date: Thu, 4 Apr 2024 02:44:59 +0800 Subject: [PATCH 09/35] feat: /v1/models now only return available models --- controller/model.go | 35 ++++++++++++++++++++++++++++++++++- middleware/auth.go | 3 ++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/controller/model.go b/controller/model.go index bf4b83a7..53649391 100644 --- a/controller/model.go +++ b/controller/model.go @@ -11,6 +11,7 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -121,9 +122,41 @@ func DashboardListModels(c *gin.Context) { } func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString("available_models") != "" { + availableModels = strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + var availableOpenAIModels []OpenAIModels + for _, model := range openAIModels { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": availableOpenAIModels, }) } diff --git a/middleware/auth.go b/middleware/auth.go index 443199d0..29701524 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -115,7 +115,8 @@ func TokenAuth() func(c *gin.Context) { } c.Set("request_model", requestModel) if token.Models != nil && *token.Models != "" { - if !isModelInList(requestModel, *token.Models) { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) return } From 8b9fa3d6e452fbc95bfc37db836c69ed39f3f094 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 02:58:21 +0800 Subject: [PATCH 10/35] fix: fix GetGroupModels --- model/ability.go | 18 ++++++++++++++++++ model/channel.go | 30 +----------------------------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/model/ability.go b/model/ability.go index 48b856a2..4a48bc51 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" "gorm.io/gorm" + "sort" "strings" ) @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/channel.go b/model/channel.go index 24829bc5..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,6 @@ package model import ( - "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" @@ -9,8 +8,6 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "sort" - "strings" ) type Channel struct { @@ -28,7 +25,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` @@ -205,28 +202,3 @@ func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } - -func GetGroupModels(ctx context.Context, group string) ([]string, error) { - groupCol := "`group`" - if common.UsingPostgreSQL { - groupCol = `"group"` - } - var modelsList []string - err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error - if err != nil { - return nil, err - } - set := make(map[string]bool) - for i := 0; i < len(modelsList); i++ { - modelList := strings.Split(modelsList[i], ",") - for _, model := range modelList { - set[model] = true - } - } - modelList := make([]string, 0, len(set)) - for model := range set { - modelList = append(modelList, model) - } - sort.Strings(modelList) - return modelList, err -} From ed70881a58bc77c9be86122d95612c2d225be633 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 11:18:21 +0800 Subject: [PATCH 11/35] fix: fix token create --- controller/token.go | 1 + 1 file changed, 1 insertion(+) diff --git a/controller/token.go b/controller/token.go index c6128534..13b90de0 100644 --- a/controller/token.go +++ b/controller/token.go @@ -130,6 +130,7 @@ func AddToken(c *gin.Context) { ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, } err = cleanToken.Insert() if err != nil { From fb90747c23373bce14e92abb823f08387c005aea Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 18:53:42 +0800 Subject: [PATCH 12/35] fix: fix /v1/models return null data when no models available --- controller/model.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/model.go b/controller/model.go index 53649391..43e73c6c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -135,7 +135,7 @@ func ListModels(c *gin.Context) { for _, availableModel := range availableModels { modelSet[availableModel] = true } - var availableOpenAIModels []OpenAIModels + availableOpenAIModels := make([]OpenAIModels, 0) for _, model := range openAIModels { if _, ok := modelSet[model.Id]; ok { modelSet[model.Id] = false From 6f036bd0c937afc9e477d421dd8c3113424f313b Mon Sep 17 00:00:00 2001 From: Yang Fei Date: Thu, 4 Apr 2024 23:32:59 +0800 Subject: [PATCH 13/35] feat: add embedding-2 support for zhipu (#1273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加对智谱embedding-2模型的支持 * fix: fix usage & ratio --------- Co-authored-by: yangfei Co-authored-by: JustSong --- common/model-ratio.go | 1 + relay/channel/zhipu/adaptor.go | 44 ++++++++++++++++++++++-------- relay/channel/zhipu/constants.go | 2 +- relay/channel/zhipu/main.go | 47 ++++++++++++++++++++++++++++++++ relay/channel/zhipu/model.go | 18 ++++++++++++ 5 files changed, 100 insertions(+), 12 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index aa75042e..d8356dc2 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -91,6 +91,7 @@ var ModelRatio = map[string]float64{ "glm-4": 0.1 * RMB, "glm-4v": 0.1 * RMB, "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 0ca23d59..7b570e71 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { if a.APIVersion == "v4" { return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil } + if meta.Mode == constant.RelayModeEmbeddings { + return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil + } method := "invoke" if meta.IsStream { method = "sse-invoke" @@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + // TopP (0.0, 1.0) + request.TopP = math.Min(0.99, request.TopP) + request.TopP = math.Max(0.01, request.TopP) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) - a.SetVersionByModeName(request.Model) - if a.APIVersion == "v4" { - return request, nil + // Temperature (0.0, 1.0) + request.Temperature = math.Min(0.99, request.Temperature) + request.Temperature = math.Max(0.01, request.Temperature) + a.SetVersionByModeName(request.Model) + if a.APIVersion == "v4" { + return request, nil + } + return ConvertRequest(*request), nil } - return ConvertRequest(*request), nil } func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { @@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel if a.APIVersion == "v4" { return a.DoResponseV4(c, resp, meta) } + if meta.IsStream { err, usage = StreamHandler(c, resp) } else { - err, usage = Handler(c, resp) + if meta.Mode == constant.RelayModeEmbeddings { + err, usage = EmbeddingsHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } } return } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: "embedding-2", + Input: request.Input.(string), + } +} + func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go index 1655a59d..2daeb19c 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/channel/zhipu/constants.go @@ -2,5 +2,5 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", - "glm-4", "glm-4v", "glm-3-turbo", + "glm-4", "glm-4v", "glm-3-turbo", "embedding-2", } diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index a46fd537..f54e0504 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse EmbeddingRespone + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: response.Model, + Usage: model.Usage{ + PromptTokens: response.PromptTokens, + CompletionTokens: response.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go index b63e1d6f..3c3a7443 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/channel/zhipu/model.go @@ -44,3 +44,21 @@ type tokenData struct { Token string ExpiryTime time.Time } + +type EmbeddingRequest struct { + Model string `json:"model"` + Input string `json:"input"` +} + +type EmbeddingRespone struct { + Model string `json:"model"` + Object string `json:"object"` + Embeddings []EmbeddingData `json:"data"` + model.Usage `json:"usage"` +} + +type EmbeddingData struct { + Index int `json:"index"` + Object string `json:"object"` + Embedding []float64 `json:"embedding"` +} From f73f2e51dfcf6f15f3d26dd045ad9ae283f25760 Mon Sep 17 00:00:00 2001 From: manjieqi <40858189+manjieqi@users.noreply.github.com> Date: Fri, 5 Apr 2024 00:02:15 +0800 Subject: [PATCH 14/35] feat: update baidu model name & ratio (#1253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修正百度模型名称 * 更新百度模型名称,并保留旧版兼容以及修正单价 * chore: add more model and adjust order --------- Co-authored-by: JustSong --- common/model-ratio.go | 24 ++++++++++++++++-------- relay/channel/baidu/adaptor.go | 24 +++++++++++++++++------- relay/channel/baidu/constants.go | 17 ++++++++++++----- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index d8356dc2..94607c92 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{ "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8K": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "bge-large-zh": 0.002 * RMB, - "bge-large-en": 0.002 * RMB, - "bge-large-8k": 0.002 * RMB, + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-Bot-8K-0922": 0.024 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Lite-8K": 0.003 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 2d2e24f6..72302fdf 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { suffix += "completions_pro" case "ERNIE-Bot-4": suffix += "completions_pro" - case "ERNIE-3.5-8K": - suffix += "completions" - case "ERNIE-Bot-8K": - suffix += "ernie_bot_8k" case "ERNIE-Bot": suffix += "completions" - case "ERNIE-Speed": - suffix += "ernie_speed" case "ERNIE-Bot-turbo": suffix += "eb-instant" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-4.0-8K": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-Speed-8K": + suffix += "ernie_speed" + case "ERNIE-Speed-128K": + suffix += "ernie-speed-128k" + case "ERNIE-Lite-8K": + suffix += "ernie-lite-8k" + case "ERNIE-Tiny-8K": + suffix += "ernie-tiny-8k" case "BLOOMZ-7B": suffix += "bloomz_7b1" case "Embedding-V1": @@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { case "tao-8k": suffix += "tao_8k" default: - suffix += meta.ActualModelName + suffix += strings.ToLower(meta.ActualModelName) } fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) var accessToken string diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index 45a4e901..ccdc25c3 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -1,11 +1,18 @@ package baidu var ModelList = []string{ - "ERNIE-Bot-4", - "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", + "ERNIE-4.0-8K", + "ERNIE-Bot-8K-0922", + "ERNIE-3.5-8K", + "ERNIE-Lite-8K-0922", + "ERNIE-Speed-8K", + "ERNIE-3.5-4K-0205", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", + "ERNIE-Lite-8K", + "ERNIE-Speed-128K", + "ERNIE-Tiny-8K", + "BLOOMZ-7B", "Embedding-V1", "bge-large-zh", "bge-large-en", From 1f80b0a39fb728776fdfa635d17fbc90de56baef Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 00:13:37 +0800 Subject: [PATCH 15/35] chore: add omitempty for xunfei functions --- relay/channel/xunfei/model.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go index e9cc59a6..97a43154 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/channel/xunfei/model.go @@ -28,7 +28,7 @@ type ChatRequest struct { } `json:"message"` Functions struct { Text []model.Function `json:"text,omitempty"` - } `json:"functions"` + } `json:"functions,omitempty"` } `json:"payload"` } From 1994256bac48dc7d55d22f9a43577651eb187693 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 00:18:26 +0800 Subject: [PATCH 16/35] chore: disable channel when error message contain quota --- relay/util/common.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/relay/util/common.go b/relay/util/common.go index 535ef680..0bb76909 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -46,6 +46,9 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true } + if strings.Contains(err.Message, "quota") { + return true + } return false } From 76569bb0b64d470aee3e970fe8e82557c8931cde Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 00:31:41 +0800 Subject: [PATCH 17/35] chore: disable channel when error message contain credit or balance --- relay/util/common.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relay/util/common.go b/relay/util/common.go index 0bb76909..d1f79a26 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -49,6 +49,12 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { if strings.Contains(err.Message, "quota") { return true } + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + return true + } return false } From 054b00b7250853f9ba345dc5756ae345ec5666bf Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 00:40:48 +0800 Subject: [PATCH 18/35] docs: add API docs --- README.md | 1 + docs/API.md | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 docs/API.md diff --git a/README.md b/README.md index 2dcdbd4f..53847b45 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 +25. 支持**扩展**,详情请参考此处 [API 文档](./docs/API.md)。 ## 部署 ### 基于 Docker 进行部署 diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 00000000..9fc350ef --- /dev/null +++ b/docs/API.md @@ -0,0 +1,17 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 PR 在此放上你的拓展项目。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +欢迎此处 PR 补充。 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 \ No newline at end of file From 0a37aa4cbd322e7ff44446ae5f130a43a90a630e Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 01:10:30 +0800 Subject: [PATCH 19/35] docs: add API docs --- controller/user.go | 77 ++++++++++++++++++++++++++++++++------------ docs/API.md | 31 ++++++++++++++++-- model/log.go | 15 +++++++++ router/api-router.go | 1 + 4 files changed, 101 insertions(+), 23 deletions(-) diff --git a/controller/user.go b/controller/user.go index 8b614e5d..61055878 100644 --- a/controller/user.go +++ b/controller/user.go @@ -180,27 +180,27 @@ func Register(c *gin.Context) { } func GetAllUsers(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 - } - - order := c.DefaultQuery("order", "") - users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) - - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": users, - }) + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) } func SearchUsers(c *gin.Context) { @@ -770,3 +770,38 @@ func TopUp(c *gin.Context) { }) return } + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docs/API.md b/docs/API.md index 9fc350ef..72ae7d91 100644 --- a/docs/API.md +++ b/docs/API.md @@ -1,6 +1,10 @@ # 使用 API 操控 & 扩展 One API > 欢迎提交 PR 在此放上你的拓展项目。 +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + ## 鉴权 One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: @@ -9,9 +13,32 @@ One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下 之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: ![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + ## API 列表 > 当前 API 列表不全,请自行通过浏览器抓取前端请求 -欢迎此处 PR 补充。 +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 -如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 \ No newline at end of file +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` \ No newline at end of file diff --git a/model/log.go b/model/log.go index 4409f73e..6b679c36 100644 --- a/model/log.go +++ b/model/log.go @@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) { } } +func RecordTopupLog(userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + err := LOG_DB.Create(log).Error + if err != nil { + logger.SysError("failed to record log: " + err.Error()) + } +} + func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { diff --git a/router/api-router.go b/router/api-router.go index 4aa6d830..1558640f 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -26,6 +26,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) userRoute := apiRouter.Group("/user") { From f8cc63f00b47a2279091e122f8815050262a31e2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 01:23:11 +0800 Subject: [PATCH 20/35] feat: add user info to topup link --- web/default/src/pages/TopUp/index.js | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/web/default/src/pages/TopUp/index.js b/web/default/src/pages/TopUp/index.js index f52cb8d5..2fcf0eae 100644 --- a/web/default/src/pages/TopUp/index.js +++ b/web/default/src/pages/TopUp/index.js @@ -8,6 +8,7 @@ const TopUp = () => { const [topUpLink, setTopUpLink] = useState(''); const [userQuota, setUserQuota] = useState(0); const [isSubmitting, setIsSubmitting] = useState(false); + const [user, setUser] = useState({}); const topUp = async () => { if (redemptionCode === '') { @@ -41,7 +42,14 @@ const TopUp = () => { showError('超级管理员未设置充值链接!'); return; } - window.open(topUpLink, '_blank'); + let url = new URL(topUpLink); + let username = user.username; + let user_id = user.id; + // add username and user_id to the topup link + url.searchParams.append('username', username); + url.searchParams.append('user_id', user_id); + url.searchParams.append('transaction_id', crypto.randomUUID()); + window.open(url.toString(), '_blank'); }; const getUserQuota = async ()=>{ @@ -49,6 +57,7 @@ const TopUp = () => { const {success, message, data} = res.data; if (success) { setUserQuota(data.quota); + setUser(data); } else { showError(message); } @@ -80,7 +89,7 @@ const TopUp = () => { }} /> + ) : ( + <> + )} {status.wechat_login ? ( ) } + { + status.lark_client_id && ( + + ) + }