From 4c7c3b699965847a793b74bf7561f51350de2926 Mon Sep 17 00:00:00 2001 From: igophper Date: Thu, 21 Sep 2023 17:36:37 +0800 Subject: [PATCH] feat:zhipu support text_embedding --- common/model-ratio.go | 1 + controller/model.go | 9 ++++ controller/relay-text.go | 24 +++++++-- controller/relay-zhipu.go | 75 ++++++++++++++++++++++++++++ web/src/pages/Channel/EditChannel.js | 2 +- 5 files changed, 107 insertions(+), 4 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index eeb23e07..9cfe74ca 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -50,6 +50,7 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "text_embedding": 0.0357, // ¥0.0005 / 1k tokens "qwen-v1": 0.8572, // ¥0.012 / 1k tokens "qwen-plus-v1": 1, // ¥0.014 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 637ebe10..47af070c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -342,6 +342,15 @@ func init() { Root: "chatglm_lite", Parent: nil, }, + { + Id: "text_embedding", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "text_embedding", + Parent: nil, + }, { Id: "qwen-v1", Object: "model", diff --git a/controller/relay-text.go b/controller/relay-text.go index 5a5f355b..f8e612f0 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -173,6 +173,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { method = "sse-invoke" } + if relayMode == RelayModeEmbeddings { + method = "invoke" + } fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" @@ -261,8 +264,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: - zhipuRequest := requestOpenAI2Zhipu(textRequest) - jsonStr, err := json.Marshal(zhipuRequest) + var jsonStr []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + zhipuEmbeddingRequest := embeddingRequestOpenAI2Zhipu(textRequest) + jsonStr, err = json.Marshal(zhipuEmbeddingRequest) + default: + zhipuRequest := requestOpenAI2Zhipu(textRequest) + jsonStr, err = json.Marshal(zhipuRequest) + } if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens return nil } else { - err, usage := zhipuHandler(c, resp) + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = zhipuEmbeddingHandler(c, resp) + default: + err, usage = zhipuHandler(c, resp) + } if err != nil { return err } diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 7a4a582d..b7376bd0 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -58,6 +58,25 @@ type zhipuTokenData struct { ExpiryTime time.Time } +type ZhipuEmbeddingRequest struct { + Prompt string `json:"prompt"` +} + +type ZhipuEmbeddingResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Embedding []float64 `json:"embedding"` + Usage `json:"usage"` +} + +type ZhipuEmbeddingResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ZhipuEmbeddingResponseData `json:"data"` +} + var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 @@ -108,6 +127,27 @@ func getZhipuToken(apikey string) string { return tokenString } +func embeddingRequestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuEmbeddingRequest { + return &ZhipuEmbeddingRequest{ + Prompt: request.ParseInput()[0], // 智谱只支持一行input + } +} + +func embeddingResponseZhipu2OpenAI(response *ZhipuEmbeddingResponse) *OpenAIEmbeddingResponse { + return &OpenAIEmbeddingResponse{ + Object: "list", + Data: []OpenAIEmbeddingResponseItem{ + { + Object: `embedding`, + Index: 0, + Embedding: response.Data.Embedding, + }, + }, + Model: "text_embedding", + Usage: response.Data.Usage, + } +} + func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { messages := make([]ZhipuMessage, 0, len(request.Messages)) for _, message := range request.Messages { @@ -299,3 +339,38 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func zhipuEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var zhipuResponse ZhipuEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&zhipuResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if !zhipuResponse.Success { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: zhipuResponse.Msg, + Type: "zhipu_error", + Param: "", + Code: zhipuResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 4c8dd0c4..70802607 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -70,7 +70,7 @@ const EditChannel = () => { localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1']; break; case 16: - localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; + localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding']; break; case 18: localModels = ['SparkDesk'];