feat: add support for gemini embedding-001

devocy 2024-02-26 14:56:24 +08:00
parent 6b27d6659a
commit 7038d2a71b
5 changed files with 126 additions and 5 deletions

View File

@@ -127,6 +127,7 @@ var ModelRatio = map[string]float64{
    "moonshot-v1-8k":   0.012 * RMB,
    "moonshot-v1-32k":  0.024 * RMB,
    "moonshot-v1-128k": 0.06 * RMB,
    "embedding-001":    0.01 * RMB,
}
func ModelRatio2JSONString() string {
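As a quick sanity check on the new entry, here is a minimal standalone sketch of what 0.01 * RMB works out to, assuming one-api's usual convention that a ratio of 1 corresponds to $0.002 per 1K tokens and that RMB is defined as USD / 7 with USD = 500; both constants are restated locally, so treat this as illustrative rather than the project's code:

package main

import "fmt"

const (
    USD = 500       // assumption: ratio 1 == $0.002 per 1K tokens, so $1 == 500
    RMB = USD / 7.0 // assumption: exchange convention used for RMB-priced models
)

func main() {
    ratio := 0.01 * RMB       // the value this commit assigns to embedding-001
    usdPer1K := ratio * 0.002 // back out the dollar price per 1K tokens
    fmt.Printf("ratio %.4f -> $%.6f per 1K tokens (about ¥0.01)\n", ratio, usdPer1K)
}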

View File

@@ -7,6 +7,7 @@ import (
    "github.com/songquanpeng/one-api/common/helper"
    channelhelper "github.com/songquanpeng/one-api/relay/channel"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    relaymode "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/util"
    "io"
@@ -17,15 +18,19 @@ type Adaptor struct {
}

func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    version := helper.AssignOrDefault(meta.APIVersion, "v1")
    action := "generateContent"
-   if meta.IsStream {
+   if relaymode.RelayModeEmbeddings == meta.Mode {
+       action = "batchEmbedContents"
+   } else if meta.IsStream {
        action = "streamGenerateContent"
    }
    return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
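To make the new branch concrete, a hedged sketch of the request URLs GetRequestURL should now produce, assuming the stock Gemini base URL (https://generativelanguage.googleapis.com) and the default "v1" API version:

package main

import "fmt"

func main() {
    base, version := "https://generativelanguage.googleapis.com", "v1" // assumed defaults
    cases := []struct{ model, action string }{
        {"embedding-001", "batchEmbedContents"}, // RelayModeEmbeddings
        {"gemini-pro", "streamGenerateContent"}, // streaming chat
        {"gemini-pro", "generateContent"},       // non-streaming chat
    }
    for _, c := range cases {
        fmt.Printf("%s/%s/models/%s:%s\n", base, version, c.model, c.action)
    }
}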
@@ -39,7 +44,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    if request == nil {
        return nil, errors.New("request is nil")
    }
-   return ConvertRequest(*request), nil
+   if relaymode.RelayModeEmbeddings == relayMode {
+       return ConvertEmbeddingRequest(*request), nil
+   } else {
+       return ConvertRequest(*request), nil
+   }
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -47,7 +57,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
-   if meta.IsStream {
+   if relaymode.RelayModeEmbeddings == meta.Mode {
+       err, usage = EmbeddingHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
+   } else if meta.IsStream {
        var responseText string
        err, responseText = StreamHandler(c, resp)
        usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
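With the dispatch above in place, a standard OpenAI-style embeddings call against a one-api deployment that routes to a Gemini channel should exercise the new path end to end. A client sketch under stated assumptions; the host, port, and token below are placeholders, not values from this commit:

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
)

func main() {
    // OpenAI-compatible embeddings request; one-api translates it for Gemini.
    body := bytes.NewBufferString(`{"model":"embedding-001","input":["hello","world"]}`)
    req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/embeddings", body)
    if err != nil {
        panic(err)
    }
    req.Header.Set("Authorization", "Bearer sk-your-one-api-token") // placeholder token
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    out, _ := io.ReadAll(resp.Body)
    fmt.Println(resp.Status, string(out)) // expect an OpenAI-shaped embedding list
}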

View File

@@ -3,4 +3,5 @@ package gemini
var ModelList = []string{
    "gemini-pro",
    "gemini-pro-vision",
    "embedding-001",
}

View File

@@ -25,7 +25,7 @@ const (
    VisionMaxImageNum = 16
)

-// Setting safety to the lowest possible values since Gemini is already powerless enough
+// ConvertRequest sets safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
    geminiRequest := ChatRequest{
        Contents: make([]ChatContent, 0, len(textRequest.Messages)),
@@ -122,6 +122,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
    return &geminiRequest
}
// ConvertEmbeddingRequest converts a GeneralOpenAIRequest into a Gemini batchEmbedContents payload (an EmbeddingMultiRequest)
func ConvertEmbeddingRequest(textRequest model.GeneralOpenAIRequest) *EmbeddingMultiRequest {
    inputs := textRequest.ParseInput()
    requests := make([]EmbeddingRequest, 0, len(inputs))
    for _, input := range inputs {
        requests = append(requests, EmbeddingRequest{
            // Gemini expects a model name in every batch entry; only embedding-001 is wired up so far
            Model: "models/embedding-001",
            Content: ChatContent{
                Parts: []Part{
                    {
                        Text: input,
                    },
                },
            },
        })
    }
    return &EmbeddingMultiRequest{
        Requests: requests,
    }
}
type ChatResponse struct {
    Candidates     []ChatCandidate    `json:"candidates"`
    PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -258,6 +279,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
    return nil, responseText
}
// EmbeddingHandler converts a Gemini batchEmbedContents response into the OpenAI embedding format and writes it to the client
func EmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    // Gemini reports failures inside a top-level "error" object, so probe for that envelope first
    var errorResponse struct {
        Error Error `json:"error"`
    }
    err = json.Unmarshal(body, &errorResponse)
    if err == nil && errorResponse.Error.Code != 0 {
        errorType := ""
        if len(errorResponse.Error.Details) > 0 {
            errorType = errorResponse.Error.Details[0].Type
        }
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: errorResponse.Error.Message,
                Type:    errorType,
                Param:   errorResponse.Error.Status,
                Code:    errorResponse.Error.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    var geminiResponse EmbeddingResponse
    err = json.Unmarshal(body, &geminiResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    fullTextResponse := embeddingResponseGemini2OpenAI(&geminiResponse, promptTokens, modelName)
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
@@ -301,3 +361,21 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
    _, err = c.Writer.Write(jsonResponse)
    return nil, &usage
}
// embeddingResponseGemini2OpenAI maps Gemini embedding vectors onto the OpenAI embedding response shape
func embeddingResponseGemini2OpenAI(geminiResponse *EmbeddingResponse, promptTokens int, modelName string) *openai.EmbeddingResponse {
    data := make([]openai.EmbeddingResponseItem, 0, len(geminiResponse.Embeddings))
    for index, embedding := range geminiResponse.Embeddings {
        data = append(data, openai.EmbeddingResponseItem{
            Object:    "embedding",
            Embedding: embedding.Values,
            Index:     index,
        })
    }
    return &openai.EmbeddingResponse{
        Object: "list",
        Data:   data,
        Model:  modelName,
        Usage:  model.Usage{TotalTokens: promptTokens},
    }
}
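Putting the two conversion functions side by side, a self-contained sketch of the round trip: an OpenAI-style input list becomes a Gemini batchEmbedContents payload, and Gemini's vectors come back as OpenAI-style items. The types are simplified local mirrors of the ones this commit adds, and the vectors are made up:

package main

import (
    "encoding/json"
    "fmt"
)

// Simplified mirrors of the request/response types added in this commit.
type Part struct {
    Text string `json:"text,omitempty"`
}
type ChatContent struct {
    Parts []Part `json:"parts"`
}
type EmbeddingRequest struct {
    Model   string      `json:"model"`
    Content ChatContent `json:"content"`
}
type EmbeddingMultiRequest struct {
    Requests []EmbeddingRequest `json:"requests"`
}
type EmbeddingData struct {
    Values []float64 `json:"values"`
}
type EmbeddingResponse struct {
    Embeddings []EmbeddingData `json:"embeddings"`
}

func main() {
    // Request side: mirrors ConvertEmbeddingRequest.
    inputs := []string{"hello", "world"} // what textRequest.ParseInput() might yield
    var batch EmbeddingMultiRequest
    for _, in := range inputs {
        batch.Requests = append(batch.Requests, EmbeddingRequest{
            Model:   "models/embedding-001",
            Content: ChatContent{Parts: []Part{{Text: in}}},
        })
    }
    payload, _ := json.MarshalIndent(batch, "", "  ")
    fmt.Println(string(payload)) // body POSTed to models/embedding-001:batchEmbedContents

    // Response side: mirrors embeddingResponseGemini2OpenAI with fabricated vectors.
    gemini := EmbeddingResponse{Embeddings: []EmbeddingData{
        {Values: []float64{0.1, 0.2}},
        {Values: []float64{0.3, 0.4}},
    }}
    type openAIItem struct {
        Object    string    `json:"object"`
        Embedding []float64 `json:"embedding"`
        Index     int       `json:"index"`
    }
    items := make([]openAIItem, 0, len(gemini.Embeddings))
    for i, e := range gemini.Embeddings {
        items = append(items, openAIItem{Object: "embedding", Embedding: e.Values, Index: i})
    }
    out, _ := json.MarshalIndent(items, "", "  ")
    fmt.Println(string(out))
}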

View File

@@ -39,3 +39,32 @@ type ChatGenerationConfig struct {
    CandidateCount int      `json:"candidateCount,omitempty"`
    StopSequences  []string `json:"stopSequences,omitempty"`
}

// Error mirrors the Google API error envelope returned by Gemini
type Error struct {
    Code    int    `json:"code"`
    Message string `json:"message"`
    Status  string `json:"status"`
    Details []struct {
        Type     string            `json:"@type"`
        Reason   string            `json:"reason"`
        Domain   string            `json:"domain"`
        Metadata map[string]string `json:"metadata"`
    } `json:"details"`
}

type EmbeddingRequest struct {
    Model   string      `json:"model"`
    Content ChatContent `json:"content"`
}

type EmbeddingMultiRequest struct {
    Requests []EmbeddingRequest `json:"requests"`
}

type EmbeddingResponse struct {
    Embeddings []EmbeddingData `json:"embeddings"`
}

type EmbeddingData struct {
    Values []float64 `json:"values"`
}
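Finally, a sketch of how the new Error type lines up with the Google API error envelope Gemini uses; note the handler has to unwrap the top-level "error" object before these fields populate. The sample payload is illustrative, not captured from a real response:

package main

import (
    "encoding/json"
    "fmt"
)

// Simplified mirror of the Error type added in this commit.
type Error struct {
    Code    int    `json:"code"`
    Message string `json:"message"`
    Status  string `json:"status"`
    Details []struct {
        Type   string `json:"@type"`
        Reason string `json:"reason"`
        Domain string `json:"domain"`
    } `json:"details"`
}

func main() {
    sample := []byte(`{
      "error": {
        "code": 400,
        "message": "API key not valid. Please pass a valid API key.",
        "status": "INVALID_ARGUMENT",
        "details": [{
          "@type": "type.googleapis.com/google.rpc.ErrorInfo",
          "reason": "API_KEY_INVALID",
          "domain": "googleapis.com"
        }]
      }
    }`)
    var envelope struct {
        Error Error `json:"error"`
    }
    if err := json.Unmarshal(sample, &envelope); err != nil {
        panic(err)
    }
    fmt.Printf("%d %s: %s\n", envelope.Error.Code, envelope.Error.Status, envelope.Error.Message)
}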