From 6f036bd0c937afc9e477d421dd8c3113424f313b Mon Sep 17 00:00:00 2001
From: Yang Fei
Date: Thu, 4 Apr 2024 23:32:59 +0800
Subject: [PATCH] feat: add embedding-2 support for zhipu (#1273)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* 增加对智谱embedding-2模型的支持

* fix: fix usage & ratio

---------

Co-authored-by: yangfei
Co-authored-by: JustSong
---
 common/model-ratio.go            |  1 +
 relay/channel/zhipu/adaptor.go   | 73 +++++++++++++++++++++++++++-----
 relay/channel/zhipu/constants.go |  2 +-
 relay/channel/zhipu/main.go      | 52 +++++++++++++++++++++++
 relay/channel/zhipu/model.go     | 22 ++++++++++
 5 files changed, 138 insertions(+), 12 deletions(-)

diff --git a/common/model-ratio.go b/common/model-ratio.go
index aa75042e..d8356dc2 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -91,6 +91,7 @@ var ModelRatio = map[string]float64{
 	"glm-4":         0.1 * RMB,
 	"glm-4v":        0.1 * RMB,
 	"glm-3-turbo":   0.005 * RMB,
+	"embedding-2":   0.0005 * RMB,
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 0ca23d59..7b570e71 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -6,6 +6,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	if a.APIVersion == "v4" {
 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
 	}
+	if meta.Mode == constant.RelayModeEmbeddings {
+		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
+	}
 	method := "invoke"
 	if meta.IsStream {
 		method = "sse-invoke"
@@ -53,18 +57,28 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	// TopP (0.0, 1.0)
-	request.TopP = math.Min(0.99, request.TopP)
-	request.TopP = math.Max(0.01, request.TopP)
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		// Reject unsupported input shapes up front so a malformed client
+		// request yields an error instead of a panic.
+		if _, ok := embeddingInput(request.Input); !ok {
+			return nil, errors.New("zhipu embedding only supports a single string input")
+		}
+		return ConvertEmbeddingRequest(*request), nil
+	default:
+		// TopP (0.0, 1.0)
+		request.TopP = math.Min(0.99, request.TopP)
+		request.TopP = math.Max(0.01, request.TopP)
 
-	// Temperature (0.0, 1.0)
-	request.Temperature = math.Min(0.99, request.Temperature)
-	request.Temperature = math.Max(0.01, request.Temperature)
-	a.SetVersionByModeName(request.Model)
-	if a.APIVersion == "v4" {
-		return request, nil
+		// Temperature (0.0, 1.0)
+		request.Temperature = math.Min(0.99, request.Temperature)
+		request.Temperature = math.Max(0.01, request.Temperature)
+		a.SetVersionByModeName(request.Model)
+		if a.APIVersion == "v4" {
+			return request, nil
+		}
+		return ConvertRequest(*request), nil
 	}
-	return ConvertRequest(*request), nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -84,14 +98,51 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 	if a.APIVersion == "v4" {
 		return a.DoResponseV4(c, resp, meta)
 	}
+
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
-		err, usage = Handler(c, resp)
+		if meta.Mode == constant.RelayModeEmbeddings {
+			err, usage = EmbeddingsHandler(c, resp)
+		} else {
+			err, usage = Handler(c, resp)
+		}
 	}
 	return
 }
 
+// ConvertEmbeddingRequest builds the Zhipu embedding payload from an
+// OpenAI-style request. The model is always "embedding-2". Input shapes other
+// than a string (or a one-element list holding a string) yield an empty
+// Input rather than a panic; Adaptor.ConvertRequest rejects them beforehand.
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	input, _ := embeddingInput(request.Input)
+	return &EmbeddingRequest{
+		Model: "embedding-2",
+		Input: input,
+	}
+}
+
+// embeddingInput extracts the single text to embed from an OpenAI-style
+// "input" value: either a plain string or a one-element list holding a string.
+// It reports false for any other shape.
+func embeddingInput(input interface{}) (string, bool) {
+	switch v := input.(type) {
+	case string:
+		return v, true
+	case []interface{}:
+		if len(v) == 1 {
+			s, ok := v[0].(string)
+			return s, ok
+		}
+	case []string:
+		if len(v) == 1 {
+			return v[0], true
+		}
+	}
+	return "", false
+}
+
 func (a *Adaptor) GetModelList() []string {
 	return ModelList
 }
diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go
index 1655a59d..2daeb19c 100644
--- a/relay/channel/zhipu/constants.go
+++ b/relay/channel/zhipu/constants.go
@@ -2,5 +2,5 @@ package zhipu
 
 var ModelList = []string{
 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
-	"glm-4", "glm-4v", "glm-3-turbo",
+	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
 }
diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go
index a46fd537..f54e0504 100644
--- a/relay/channel/zhipu/main.go
+++ b/relay/channel/zhipu/main.go
@@ -254,3 +254,55 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+// EmbeddingsHandler reads and closes the upstream embedding response, converts
+// it to the OpenAI embedding format, and writes it to the client with the
+// upstream status code. It returns the usage from the converted response.
+func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var zhipuResponse EmbeddingRespone
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &zhipuResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+// embeddingResponseZhipu2OpenAI maps a Zhipu embedding response onto the OpenAI
+// embedding response shape, copying the model name, usage and every vector.
+func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
+	openAIEmbeddingResponse := openai.EmbeddingResponse{
+		Object: "list",
+		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+		Model:  response.Model,
+		Usage: model.Usage{
+			PromptTokens:     response.PromptTokens,
+			CompletionTokens: response.CompletionTokens,
+			TotalTokens:      response.Usage.TotalTokens,
+		},
+	}
+
+	for _, item := range response.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go
index b63e1d6f..3c3a7443 100644
--- a/relay/channel/zhipu/model.go
+++ b/relay/channel/zhipu/model.go
@@ -44,3 +44,25 @@ type tokenData struct {
 	Token      string
 	ExpiryTime time.Time
 }
+
+// EmbeddingRequest is the JSON body sent to the Zhipu embeddings endpoint.
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input string `json:"input"`
+}
+
+// EmbeddingRespone is the JSON body returned by the Zhipu embeddings endpoint.
+// The embedded model.Usage is decoded from the "usage" field.
+type EmbeddingRespone struct {
+	Model       string          `json:"model"`
+	Object      string          `json:"object"`
+	Embeddings  []EmbeddingData `json:"data"`
+	model.Usage `json:"usage"`
+}
+
+// EmbeddingData is one entry of the "data" list in an EmbeddingRespone.
+type EmbeddingData struct {
+	Index     int       `json:"index"`
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+}