feat:zhipu support text_embedding
This commit is contained in:
parent
8651451e53
commit
4c7c3b6999
@ -50,6 +50,7 @@ var ModelRatio = map[string]float64{
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"text_embedding": 0.0357, // ¥0.0005 / 1k tokens
|
||||
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
|
@ -342,6 +342,15 @@ func init() {
|
||||
Root: "chatglm_lite",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text_embedding",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "text_embedding",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-v1",
|
||||
Object: "model",
|
||||
|
@ -173,6 +173,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if textRequest.Stream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
if relayMode == RelayModeEmbeddings {
|
||||
method = "invoke"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
@ -261,8 +264,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeZhipu:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err := json.Marshal(zhipuRequest)
|
||||
var jsonStr []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
zhipuEmbeddingRequest := embeddingRequestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err = json.Marshal(zhipuEmbeddingRequest)
|
||||
default:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err = json.Marshal(zhipuRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
} else {
|
||||
err, usage := zhipuHandler(c, resp)
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = zhipuEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = zhipuHandler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -58,6 +58,25 @@ type zhipuTokenData struct {
|
||||
ExpiryTime time.Time
|
||||
}
|
||||
|
||||
type ZhipuEmbeddingRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ZhipuEmbeddingResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuEmbeddingResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuEmbeddingResponseData `json:"data"`
|
||||
}
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
@ -108,6 +127,27 @@ func getZhipuToken(apikey string) string {
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuEmbeddingRequest {
|
||||
return &ZhipuEmbeddingRequest{
|
||||
Prompt: request.ParseInput()[0], // 智谱只支持一行input
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingResponseZhipu2OpenAI(response *ZhipuEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
return &OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []OpenAIEmbeddingResponseItem{
|
||||
{
|
||||
Object: `embedding`,
|
||||
Index: 0,
|
||||
Embedding: response.Data.Embedding,
|
||||
},
|
||||
},
|
||||
Model: "text_embedding",
|
||||
Usage: response.Data.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
@ -299,3 +339,38 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var zhipuResponse ZhipuEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&zhipuResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if !zhipuResponse.Success {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ const EditChannel = () => {
|
||||
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
|
||||
break;
|
||||
case 16:
|
||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding'];
|
||||
break;
|
||||
case 18:
|
||||
localModels = ['SparkDesk'];
|
||||
|
Loading…
Reference in New Issue
Block a user