diff --git a/common/model-ratio.go b/common/model-ratio.go
index 2e7aae71..3f2e1fb1 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -127,6 +127,7 @@ var ModelRatio = map[string]float64{
 	"moonshot-v1-8k":   0.012 * RMB,
 	"moonshot-v1-32k":  0.024 * RMB,
 	"moonshot-v1-128k": 0.06 * RMB,
+	"embedding-001":    0.01 * RMB,
 }
 
 func ModelRatio2JSONString() string {
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index f3305e5d..554ee349 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -7,6 +7,7 @@ import (
 	"github.com/songquanpeng/one-api/common/helper"
 	channelhelper "github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	relaymode "github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
@@ -17,15 +18,19 @@ type Adaptor struct {
 }
 
 func (a *Adaptor) Init(meta *util.RelayMeta) {
 
 }
 
 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	version := helper.AssignOrDefault(meta.APIVersion, "v1")
 	action := "generateContent"
-	if meta.IsStream {
+
+	if relaymode.RelayModeEmbeddings == meta.Mode {
+		action = "batchEmbedContents"
+	} else if meta.IsStream {
 		action = "streamGenerateContent"
 	}
+
 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 }
 
@@ -39,7 +44,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return ConvertRequest(*request), nil
+
+	if relaymode.RelayModeEmbeddings == relayMode {
+		return ConvertEmbeddingRequest(*request), nil
+	}
+	return ConvertRequest(*request), nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -47,7 +56,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
-	if meta.IsStream {
+	if relaymode.RelayModeEmbeddings == meta.Mode {
+		err, usage = EmbeddingHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
+	} else if meta.IsStream {
 		var responseText string
 		err, responseText = StreamHandler(c, resp)
 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
index 5bb0c168..d35f22e0 100644
--- a/relay/channel/gemini/constants.go
+++ b/relay/channel/gemini/constants.go
@@ -3,4 +3,5 @@ package gemini
 var ModelList = []string{
 	"gemini-pro",
 	"gemini-pro-vision",
+	"(Gemini)embedding-001",
 }
diff --git a/relay/channel/gemini/main.go b/relay/channel/gemini/main.go
index c24694c8..60f8ed6f 100644
--- a/relay/channel/gemini/main.go
+++ b/relay/channel/gemini/main.go
@@ -25,7 +25,7 @@ const (
 	VisionMaxImageNum = 16
 )
 
-// Setting safety to the lowest possible values since Gemini is already powerless enough
+// ConvertRequest builds a Gemini chat request, setting safety to the lowest possible values since Gemini is already powerless enough.
 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 	geminiRequest := ChatRequest{
 		Contents: make([]ChatContent, 0, len(textRequest.Messages)),
@@ -122,6 +122,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 	return &geminiRequest
 }
 
+// ConvertEmbeddingRequest converts a GeneralOpenAIRequest into a Gemini batch embedding request.
+func ConvertEmbeddingRequest(textRequest model.GeneralOpenAIRequest) *EmbeddingMultiRequest {
+	inputs := textRequest.ParseInput()
+	requests := make([]EmbeddingRequest, 0, len(inputs))
+	for _, input := range inputs {
+		requests = append(requests, EmbeddingRequest{
+			Model: "models/embedding-001",
+			Content: ChatContent{
+				Parts: []Part{
+					{
+						Text: input,
+					},
+				},
+			},
+		})
+	}
+	return &EmbeddingMultiRequest{
+		Requests: requests,
+	}
+}
+
 type ChatResponse struct {
 	Candidates     []ChatCandidate    `json:"candidates"`
 	PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -258,6 +279,49 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	return nil, responseText
 }
 
+// EmbeddingHandler converts a Gemini batch embedding response to the OpenAI format and writes it to the client.
+func EmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var geminiError Error
+	err = json.Unmarshal(body, &geminiError)
+	if err == nil && geminiError.Code != 0 {
+		// Details may be empty; guard against an index-out-of-range panic.
+		detailType := ""
+		if len(geminiError.Details) > 0 {
+			detailType = geminiError.Details[0].Type
+		}
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
+				Message: geminiError.Message,
+				Type:    detailType,
+				Param:   geminiError.Status,
+				Code:    geminiError.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	var geminiResponse EmbeddingResponse
+	err = json.Unmarshal(body, &geminiResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	fullTextResponse := embeddingResponseGemini2OpenAI(&geminiResponse, promptTokens, modelName)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -301,3 +365,21 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+func embeddingResponseGemini2OpenAI(geminiResponse *EmbeddingResponse, promptTokens int, modelName string) *openai.EmbeddingResponse {
+	data := make([]openai.EmbeddingResponseItem, 0, len(geminiResponse.Embeddings))
+
+	for index, embedding := range geminiResponse.Embeddings {
+		data = append(data, openai.EmbeddingResponseItem{
+			Object:    "embedding",
+			Embedding: embedding.Values,
+			Index:     index,
+		})
+	}
+	return &openai.EmbeddingResponse{
+		Object: "list",
+		Data:   data,
+		Model:  modelName,
+		Usage:  model.Usage{TotalTokens: promptTokens},
+	}
+}
diff --git a/relay/channel/gemini/model.go b/relay/channel/gemini/model.go
index d1e3c4fd..e4ae1065 100644
--- a/relay/channel/gemini/model.go
+++ b/relay/channel/gemini/model.go
@@ -39,3 +39,32 @@ type ChatGenerationConfig struct {
 	CandidateCount  int      `json:"candidateCount,omitempty"`
 	StopSequences   []string `json:"stopSequences,omitempty"`
 }
+
+type Error struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  string `json:"status"`
+	Details []struct {
+		Type     string            `json:"@type"`
+		Reason   string            `json:"reason"`
+		Domain   string            `json:"domain"`
+		Metadata map[string]string `json:"metadata"`
+	} `json:"details"`
+}
+
+type EmbeddingRequest struct {
+	Model   string      `json:"model"`
+	Content ChatContent `json:"content"`
+}
+
+type EmbeddingMultiRequest struct {
+	Requests []EmbeddingRequest `json:"requests"`
+}
+
+type EmbeddingResponse struct {
+	Embeddings []EmbeddingData `json:"embeddings"`
+}
+
+type EmbeddingData struct {
+	Values []float64 `json:"values"`
+}