From d0a0e871e117591fdf018e619c5d71ee371331f6 Mon Sep 17 00:00:00 2001
From: igophper <34326532+igophper@users.noreply.github.com>
Date: Sun, 3 Sep 2023 22:12:35 +0800
Subject: [PATCH] fix: support ali's embedding model (#481, close #469)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat:支持阿里的 embedding 模型

* fix: add to model list

---------

Co-authored-by: JustSong
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
---
Review notes (this region is ignored by `git am`; not part of the commit):
* The line structure of this patch was rebuilt from a whitespace-collapsed
  copy. Tabs and gofmt column alignment are reconstructed by convention, not
  recovered, so context lines may not match the target files byte-for-byte.
  If a hunk is rejected, retry with `git apply --ignore-whitespace`.
* aliEmbeddingHandler: the error returned by c.Writer.Write is assigned but
  never checked, and resp.Body is not closed when JSON decoding fails.
* ParseInput: the type switch repeats r.Input.(T) assertions instead of
  binding the switched value (`switch v := r.Input.(type)`).

 common/model-ratio.go                |  7 ++-
 controller/model.go                  |  9 +++
 controller/relay-ali.go              | 88 ++++++++++++++++++++++++++++
 controller/relay-baidu.go            | 15 +----
 controller/relay-text.go             | 24 +++++++-
 controller/relay.go                  | 19 ++++++
 web/src/pages/Channel/EditChannel.js |  2 +-
 7 files changed, 144 insertions(+), 20 deletions(-)

diff --git a/common/model-ratio.go b/common/model-ratio.go
index 70758805..eeb23e07 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{
 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
-	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
-	"qwen-plus-v1":            0.5715, // Same as above
-	"SparkDesk":               0.8572, // TBD
+	"qwen-v1":                 0.8572, // ¥0.012 / 1k tokens
+	"qwen-plus-v1":            1,      // ¥0.014 / 1k tokens
+	"text-embedding-v1":       0.05,   // ¥0.0007 / 1k tokens
+	"SparkDesk":               1.2858, // ¥0.018 / 1k tokens
 	"360GPT_S2_V9":            0.8572, // ¥0.012 / 1k tokens
 	"embedding-bert-512-v1":   0.0715, // ¥0.001 / 1k tokens
 	"embedding_s1_v1":         0.0715, // ¥0.001 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index 88f95f7b..637ebe10 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -360,6 +360,15 @@ func init() {
 			Root:       "qwen-plus-v1",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-embedding-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "text-embedding-v1",
+			Parent:     nil,
+		},
 		{
 			Id:         "SparkDesk",
 			Object:     "model",
diff --git a/controller/relay-ali.go b/controller/relay-ali.go
index 9dca9a89..50dc743c 100644
--- a/controller/relay-ali.go
+++ b/controller/relay-ali.go
@@ -35,6 +35,29 @@ type AliChatRequest struct {
 	Parameters AliParameters `json:"parameters,omitempty"`
 }
 
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
 type AliError struct {
 	Code      string `json:"code"`
 	Message   string `json:"message"`
@@ -44,6 +67,7 @@ type AliError struct {
 type AliUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
 }
 
 type AliOutput struct {
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	}
 }
 
+func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+	return &AliEmbeddingRequest{
+		Model: "text-embedding-v1",
+		Input: struct {
+			Texts []string `json:"texts"`
+		}{
+			Texts: request.ParseInput(),
+		},
+	}
+}
+
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliEmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Model:  "text-embedding-v1",
+		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+	}
+
+	for _, item := range response.Output.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.TextIndex,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	choice := OpenAITextResponseChoice{
 		Index: 0,
diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go
index 39f31a9a..ed08ac04 100644
--- a/controller/relay-baidu.go
+++ b/controller/relay-baidu.go
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 }
 
 func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
-	baiduEmbeddingRequest := BaiduEmbeddingRequest{
-		Input: nil,
+	return &BaiduEmbeddingRequest{
+		Input: request.ParseInput(),
 	}
-	switch request.Input.(type) {
-	case string:
-		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
-	case []any:
-		for _, item := range request.Input.([]any) {
-			if str, ok := item.(string); ok {
-				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
-			}
-		}
-	}
-	return &baiduEmbeddingRequest
 }
 
 func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
diff --git a/controller/relay-text.go b/controller/relay-text.go
index c6659799..b190a999 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -174,6 +174,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 	case APITypeAli:
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+		if relayMode == RelayModeEmbeddings {
+			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+		}
 	case APITypeAIProxyLibrary:
 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 	}
@@ -262,8 +265,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeAli:
-		aliRequest := requestOpenAI2Ali(textRequest)
-		jsonStr, err := json.Marshal(aliRequest)
+		var jsonStr []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliEmbeddingRequest)
+		default:
+			aliRequest := requestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliRequest)
+		}
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
@@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		} else {
-			err, usage := aliHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = aliEmbeddingHandler(c, resp)
+			default:
+				err, usage = aliHandler(c, resp)
+			}
 			if err != nil {
 				return err
 			}
diff --git a/controller/relay.go b/controller/relay.go
index 056d42d3..d20663f6 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
 	Functions   any       `json:"functions,omitempty"`
 }
 
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch r.Input.(type) {
+	case string:
+		input = []string{r.Input.(string)}
+	case []any:
+		input = make([]string, 0, len(r.Input.([]any)))
+		for _, item := range r.Input.([]any) {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
+
 type ChatRequest struct {
 	Model     string    `json:"model"`
 	Messages  []Message `json:"messages"`
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 2ad3dc6a..78ff1952 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -67,7 +67,7 @@ const EditChannel = () => {
           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
           break;
         case 17:
-          localModels = ['qwen-v1', 'qwen-plus-v1'];
+          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
           break;
         case 16:
           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];