feat: zhipu supports text_embedding

This commit is contained in:
igophper 2023-09-21 17:36:37 +08:00
parent 8651451e53
commit 4c7c3b6999
5 changed files with 107 additions and 4 deletions

View File

@ -50,6 +50,7 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"text_embedding": 0.0357, // ¥0.0005 / 1k tokens
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens

View File

@ -342,6 +342,15 @@ func init() {
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "text_embedding",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "text_embedding",
Parent: nil,
},
{
Id: "qwen-v1",
Object: "model",

View File

@ -173,6 +173,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.Stream {
method = "sse-invoke"
}
if relayMode == RelayModeEmbeddings {
method = "invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
@ -261,8 +264,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
zhipuEmbeddingRequest := embeddingRequestOpenAI2Zhipu(textRequest)
jsonStr, err = json.Marshal(zhipuEmbeddingRequest)
default:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err = json.Marshal(zhipuRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = zhipuEmbeddingHandler(c, resp)
default:
err, usage = zhipuHandler(c, resp)
}
if err != nil {
return err
}

View File

@ -58,6 +58,25 @@ type zhipuTokenData struct {
ExpiryTime time.Time
}
// ZhipuEmbeddingRequest is the request body sent to Zhipu's embedding
// endpoint. Zhipu accepts a single prompt string per request (see
// embeddingRequestOpenAI2Zhipu, which forwards only the first input).
type ZhipuEmbeddingRequest struct {
Prompt string `json:"prompt"`
}
// ZhipuEmbeddingResponseData is the payload portion of a Zhipu embedding
// response: task bookkeeping fields, the embedding vector, and token usage.
type ZhipuEmbeddingResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Embedding []float64 `json:"embedding"`
// Embedded so usage fields unmarshal directly from the "usage" object.
Usage `json:"usage"`
}
// ZhipuEmbeddingResponse is the envelope Zhipu wraps around embedding
// results; Success/Code/Msg report API-level errors even on HTTP 200.
type ZhipuEmbeddingResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuEmbeddingResponseData `json:"data"`
}
// zhipuTokens caches per-API-key token data (see zhipuTokenData /
// getZhipuToken) so a token is not regenerated on every request.
var zhipuTokens sync.Map
// expSeconds is the cached-token lifetime in seconds (24 hours).
var expSeconds int64 = 24 * 3600
@ -108,6 +127,27 @@ func getZhipuToken(apikey string) string {
return tokenString
}
// embeddingRequestOpenAI2Zhipu converts an OpenAI-format embedding request
// into Zhipu's request shape. Zhipu's embedding API accepts only a single
// prompt string, so only the first parsed input is forwarded; any additional
// inputs are silently dropped.
func embeddingRequestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuEmbeddingRequest {
	inputs := request.ParseInput()
	// Guard against an empty input list so indexing cannot panic on a
	// request with no input.
	if len(inputs) == 0 {
		return &ZhipuEmbeddingRequest{}
	}
	return &ZhipuEmbeddingRequest{
		Prompt: inputs[0], // Zhipu only supports a single input line
	}
}
// embeddingResponseZhipu2OpenAI maps a Zhipu embedding response onto the
// OpenAI-compatible response shape expected by downstream clients. The
// single returned vector becomes a one-element data list at index 0, and
// the upstream usage numbers are passed through unchanged.
func embeddingResponseZhipu2OpenAI(response *ZhipuEmbeddingResponse) *OpenAIEmbeddingResponse {
	item := OpenAIEmbeddingResponseItem{
		Object:    "embedding",
		Index:     0,
		Embedding: response.Data.Embedding,
	}
	openAIResponse := OpenAIEmbeddingResponse{
		Object: "list",
		Data:   []OpenAIEmbeddingResponseItem{item},
		Model:  "text_embedding",
		Usage:  response.Data.Usage,
	}
	return &openAIResponse
}
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages {
@ -299,3 +339,38 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
// zhipuEmbeddingHandler reads a non-streaming Zhipu embedding response,
// converts it to the OpenAI-compatible shape, relays it to the client with
// the upstream status code, and returns the token usage for billing.
// Returns a non-nil *OpenAIErrorWithStatusCode on failure, else (nil, usage).
func zhipuEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
	var zhipuResponse ZhipuEmbeddingResponse
	err := json.NewDecoder(resp.Body).Decode(&zhipuResponse)
	// Close the body unconditionally so the connection can be reused; the
	// original closed it only after a successful decode, leaking it on the
	// decode-error path.
	closeErr := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr != nil {
		return errorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	if !zhipuResponse.Success {
		// Zhipu reports API-level failures inside the JSON envelope; surface
		// them as an OpenAI-style error carrying the upstream status code.
		return &OpenAIErrorWithStatusCode{
			OpenAIError: OpenAIError{
				Message: zhipuResponse.Msg,
				Type:    "zhipu_error",
				Param:   "",
				Code:    zhipuResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Write error is deliberately ignored, matching the sibling zhipuHandler:
	// headers have already been sent, so nothing useful can be returned.
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}

View File

@ -70,7 +70,7 @@ const EditChannel = () => {
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding'];
break;
case 18:
localModels = ['SparkDesk'];