diff --git a/common/model-ratio.go b/common/model-ratio.go
index ba6d7245..123451f7 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -42,6 +42,7 @@ var ModelRatio = map[string]float64{
 	"claude-2":        30,
 	"ERNIE-Bot":       0.8572, // ¥0.012 / 1k tokens
 	"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
+	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":          1,
 	"chatglm_pro":     0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":     0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index f8096f75..123b0a2f 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -288,6 +288,15 @@ func init() {
 			Root:       "ERNIE-Bot-turbo",
 			Parent:     nil,
 		},
+		{
+			Id:         "Embedding-V1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "Embedding-V1",
+			Parent:     nil,
+		},
 		{
 			Id:         "PaLM-2",
 			Object:     "model",
diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go
index 4267757d..7960e8ee 100644
--- a/controller/relay-baidu.go
+++ b/controller/relay-baidu.go
@@ -54,6 +54,29 @@ type BaiduChatStreamResponse struct {
 	IsEnd      bool `json:"is_end"`
 }
 
+// BaiduEmbeddingRequest is the request body sent to Baidu's embedding endpoint.
+type BaiduEmbeddingRequest struct {
+	Input []string `json:"input"`
+}
+
+// BaiduEmbeddingData is a single embedding entry in a Baidu embedding response.
+type BaiduEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+// BaiduEmbeddingResponse is the response body returned by Baidu's embedding
+// endpoint. Error fields are promoted from the embedded BaiduError.
+type BaiduEmbeddingResponse struct {
+	Id      string               `json:"id"`
+	Object  string               `json:"object"`
+	Created int64                `json:"created"`
+	Data    []BaiduEmbeddingData `json:"data"`
+	Usage   Usage                `json:"usage"`
+	BaiduError
+}
+
 func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
@@ -112,6 +135,56 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 	return &response
 }
 
+// embeddingRequestOpenAI2Baidu converts an OpenAI-style embedding request into
+// the Baidu request format.
+//
+// request.Input may be a single string, a []string, or a []interface{} (the
+// form encoding/json produces when a JSON array is decoded into an interface
+// value). Array elements that are not strings are skipped; any other input
+// type leaves Input nil.
+func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+	baiduEmbeddingRequest := BaiduEmbeddingRequest{
+		Input: nil,
+	}
+	switch input := request.Input.(type) {
+	case string:
+		baiduEmbeddingRequest.Input = []string{input}
+	case []string:
+		baiduEmbeddingRequest.Input = input
+	case []interface{}:
+		// encoding/json yields []interface{} (never []string) for a JSON array
+		// stored in an interface value, so convert element by element.
+		texts := make([]string, 0, len(input))
+		for _, item := range input {
+			if text, ok := item.(string); ok {
+				texts = append(texts, text)
+			}
+		}
+		baiduEmbeddingRequest.Input = texts
+	}
+	return &baiduEmbeddingRequest
+}
+
+// embeddingResponseBaidu2OpenAI converts a Baidu embedding response into the
+// OpenAI embedding response format. Data items and usage are copied over as-is;
+// Model is always reported as "baidu-embedding".
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Model:  "baidu-embedding",
+		Usage:  response.Usage,
+	}
+	for _, item := range response.Data {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    item.Object,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 	var usage Usage
 	scanner := bufio.NewScanner(resp.Body)
@@ -212,3 +285,47 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+// baiduEmbeddingHandler reads a Baidu embedding response from resp, converts it
+// to the OpenAI embedding format and writes it to the client.
+//
+// It returns an error wrapper when the body cannot be read, closed, decoded or
+// re-encoded, or when Baidu reported an error (non-empty ErrorMsg); otherwise it
+// returns the usage taken from the Baidu response.
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var baiduResponse BaiduEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &baiduResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if baiduResponse.ErrorMsg != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: baiduResponse.ErrorMsg,
+				Type:    "baidu_error",
+				Param:   "",
+				Code:    baiduResponse.ErrorCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	// NOTE(review): the error returned by Write is assigned but never checked,
+	// same as in baiduHandler.
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index e58c810b..7d3fe1de 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -139,6 +139,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 		case "BLOOMZ-7B":
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+		case "Embedding-V1":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
 		}
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -212,12 +214,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeBaidu:
-		baiduRequest := requestOpenAI2Baidu(textRequest)
-		jsonStr, err := json.Marshal(baiduRequest)
+		var jsonData []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduEmbeddingRequest)
+		default:
+			baiduRequest := requestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduRequest)
+		}
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
-		requestBody = bytes.NewBuffer(jsonStr)
+		requestBody = bytes.NewBuffer(jsonData)
 	case APITypePaLM:
 		palmRequest := requestOpenAI2PaLM(textRequest)
 		jsonStr, err := json.Marshal(palmRequest)
@@ -386,7 +396,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		} else {
-			err, usage := baiduHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = baiduEmbeddingHandler(c, resp)
+			default:
+				err, usage = baiduHandler(c, resp)
+			}
 			if err != nil {
 				return err
 			}
diff --git a/controller/relay.go b/controller/relay.go
index 9cfa5c4f..609ae2eb 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -99,6 +99,21 @@ type OpenAITextResponse struct {
 	Usage   `json:"usage"`
 }
 
+// OpenAIEmbeddingResponseItem is a single entry in OpenAIEmbeddingResponse.Data.
+type OpenAIEmbeddingResponseItem struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+// OpenAIEmbeddingResponse is the embedding response body written back to the client.
+type OpenAIEmbeddingResponse struct {
+	Object string                        `json:"object"`
+	Data   []OpenAIEmbeddingResponseItem `json:"data"`
+	Model  string                        `json:"model"`
+	Usage  `json:"usage"`
+}
+
 type ImageResponse struct {
 	Created int `json:"created"`
 	Data    []struct {