Refactor Gemini Adaptor to Support Embeddings
This commit is contained in:
parent
b53e00a9b3
commit
5dee10cacf
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
|
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
|
||||||
action := "generateContent"
|
action := ""
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.Embeddings:
|
||||||
|
action = "batchEmbedContents"
|
||||||
|
default:
|
||||||
|
action = "generateContent"
|
||||||
|
}
|
||||||
|
|
||||||
if meta.IsStream {
|
if meta.IsStream {
|
||||||
action = "streamGenerateContent?alt=sse"
|
action = "streamGenerateContent?alt=sse"
|
||||||
}
|
}
|
||||||
@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
return ConvertRequest(*request), nil
|
switch relayMode {
|
||||||
|
case relaymode.Embeddings:
|
||||||
|
geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||||
|
return geminiEmbeddingRequest, nil
|
||||||
|
default:
|
||||||
|
geminiRequest := ConvertRequest(*request)
|
||||||
|
return geminiRequest, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
|
|||||||
err, responseText = StreamHandler(c, resp)
|
err, responseText = StreamHandler(c, resp)
|
||||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
switch meta.Mode {
|
||||||
|
case relaymode.Embeddings:
|
||||||
|
err, usage = EmbeddingHandler(c, resp)
|
||||||
|
default:
|
||||||
|
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
return &geminiRequest
|
return &geminiRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
|
||||||
|
inputs := request.ParseInput()
|
||||||
|
requests := make([]EmbeddingRequest, len(inputs))
|
||||||
|
model := fmt.Sprintf("models/%s", request.Model)
|
||||||
|
|
||||||
|
for i, input := range inputs {
|
||||||
|
requests[i] = EmbeddingRequest{
|
||||||
|
Model: model,
|
||||||
|
Content: ChatContent{
|
||||||
|
Parts: []Part{
|
||||||
|
{
|
||||||
|
Text: input,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &BatchEmbeddingRequest{
|
||||||
|
Requests: requests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
Candidates []ChatCandidate `json:"candidates"`
|
Candidates []ChatCandidate `json:"candidates"`
|
||||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||||
@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||||
|
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
|
||||||
|
Model: "gemini-embedding",
|
||||||
|
Usage: model.Usage{TotalTokens: 0},
|
||||||
|
}
|
||||||
|
for _, item := range response.Embeddings {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||||
|
Object: `embedding`,
|
||||||
|
Index: 0,
|
||||||
|
Embedding: item.Values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &openAIEmbeddingResponse
|
||||||
|
}
|
||||||
|
|
||||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
|||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
var geminiEmbeddingResponse EmbeddingResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if geminiEmbeddingResponse.Error != nil {
|
||||||
|
return &model.ErrorWithStatusCode{
|
||||||
|
Error: model.Error{
|
||||||
|
Message: geminiEmbeddingResponse.Error.Message,
|
||||||
|
Type: "gemini_error",
|
||||||
|
Param: "",
|
||||||
|
Code: geminiEmbeddingResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
@ -7,6 +7,33 @@ type ChatRequest struct {
|
|||||||
Tools []ChatTools `json:"tools,omitempty"`
|
Tools []ChatTools `json:"tools,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content ChatContent `json:"content"`
|
||||||
|
TaskType string `json:"taskType,omitempty"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BatchEmbeddingRequest struct {
|
||||||
|
Requests []EmbeddingRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingData struct {
|
||||||
|
Values []float64 `json:"values"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Embeddings []EmbeddingData `json:"embeddings"`
|
||||||
|
Error *Error `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Code int `json:"code,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type InlineData struct {
|
type InlineData struct {
|
||||||
MimeType string `json:"mimeType"`
|
MimeType string `json:"mimeType"`
|
||||||
Data string `json:"data"`
|
Data string `json:"data"`
|
||||||
|
Loading…
Reference in New Issue
Block a user