feat: add support for gemini embedding-001
This commit is contained in:
parent
6b27d6659a
commit
7038d2a71b
@ -127,6 +127,7 @@ var ModelRatio = map[string]float64{
|
||||
"moonshot-v1-8k": 0.012 * RMB,
|
||||
"moonshot-v1-32k": 0.024 * RMB,
|
||||
"moonshot-v1-128k": 0.06 * RMB,
|
||||
"embedding-001": 0.01 * RMB,
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
channelhelper "github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
relaymode "github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
@ -17,15 +18,19 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
|
||||
fmt.Println(meta.APIVersion)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
||||
version := helper.AssignOrDefault(meta.APIVersion, "v1")
|
||||
action := "generateContent"
|
||||
if meta.IsStream {
|
||||
|
||||
if relaymode.RelayModeEmbeddings == meta.Mode {
|
||||
action = "batchEmbedContents"
|
||||
} else if meta.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
||||
}
|
||||
|
||||
@ -39,7 +44,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return ConvertRequest(*request), nil
|
||||
|
||||
if relaymode.RelayModeEmbeddings == relayMode {
|
||||
return ConvertEmbeddingRequest(*request), nil
|
||||
} else {
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
|
||||
@ -47,7 +57,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
if relaymode.RelayModeEmbeddings == meta.Mode {
|
||||
err, usage = EmbeddingHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
} else if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
|
@ -3,4 +3,5 @@ package gemini
|
||||
// ModelList enumerates the Gemini models this channel supports.
var ModelList = []string{
	"gemini-pro",
	"gemini-pro-vision",
	// Must match the "embedding-001" ModelRatio key and the
	// "models/embedding-001" name sent upstream; the previous
	// "(Gemini)embedding-001" entry matched neither, so requests for the
	// listed name would never resolve a price or a real model.
	"embedding-001",
}
|
||||
|
@ -25,7 +25,7 @@ const (
|
||||
VisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
// ConvertRequest Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest := ChatRequest{
|
||||
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
||||
@ -122,6 +122,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
// ConvertEmbeddingRequest converts a GeneralOpenAIRequest to an EmbeddingMultiRequest
|
||||
func ConvertEmbeddingRequest(textRequest model.GeneralOpenAIRequest) *EmbeddingMultiRequest {
|
||||
inputs := textRequest.ParseInput()
|
||||
requests := make([]EmbeddingRequest, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
requests = append(requests, EmbeddingRequest{
|
||||
Model: "models/embedding-001",
|
||||
Content: ChatContent{
|
||||
Parts: []Part{
|
||||
{
|
||||
Text: input,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return &EmbeddingMultiRequest{
|
||||
Requests: requests,
|
||||
}
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Candidates []ChatCandidate `json:"candidates"`
|
||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||
@ -258,6 +279,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
// EmbeddingHandler is a function that handles embedding requests
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiError Error
|
||||
err = json.Unmarshal(body, &geminiError)
|
||||
if geminiError.Code != 0 || err != nil {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: geminiError.Message,
|
||||
Type: geminiError.Details[0].Type,
|
||||
Param: geminiError.Status,
|
||||
Code: geminiError.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
var geminiResponse EmbeddingResponse
|
||||
err = json.Unmarshal(body, &geminiResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseGemini2OpenAI(&geminiResponse, promptTokens, modelName)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@ -301,3 +361,21 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func embeddingResponseGemini2OpenAI(geminiResponse *EmbeddingResponse, promptTokens int, modelName string) *openai.EmbeddingResponse {
|
||||
data := make([]openai.EmbeddingResponseItem, 0, len(geminiResponse.Embeddings))
|
||||
|
||||
for index, embedding := range geminiResponse.Embeddings {
|
||||
data = append(data, openai.EmbeddingResponseItem{
|
||||
Object: "embedding",
|
||||
Embedding: embedding.Values,
|
||||
Index: index,
|
||||
})
|
||||
}
|
||||
return &openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: modelName,
|
||||
Usage: model.Usage{TotalTokens: promptTokens},
|
||||
}
|
||||
}
|
||||
|
@ -39,3 +39,32 @@ type ChatGenerationConfig struct {
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// Error mirrors the error payload returned by the Gemini API.
type Error struct {
	// Code is the numeric error code; 0 indicates no error was reported.
	Code int `json:"code"`
	// Message is the human-readable description from Gemini.
	Message string `json:"message"`
	// Status is the canonical status string (e.g. "INVALID_ARGUMENT").
	Status string `json:"status"`
	// Details carries structured diagnostic entries; it may be empty,
	// so callers must bounds-check before indexing.
	Details []struct {
		Type     string            `json:"@type"`
		Reason   string            `json:"reason"`
		Domain   string            `json:"domain"`
		Metadata map[string]string `json:"metadata"`
	} `json:"details"`
}
|
||||
|
||||
// EmbeddingRequest is a single entry in a Gemini batchEmbedContents call.
type EmbeddingRequest struct {
	// Model is the fully-qualified model name, e.g. "models/embedding-001".
	Model string `json:"model"`
	// Content holds the text parts to embed.
	Content ChatContent `json:"content"`
}
|
||||
|
||||
// EmbeddingMultiRequest is the request body for Gemini's
// batchEmbedContents endpoint: one EmbeddingRequest per input.
type EmbeddingMultiRequest struct {
	Requests []EmbeddingRequest `json:"requests"`
}
|
||||
|
||||
// EmbeddingResponse is the success body returned by batchEmbedContents;
// embeddings are ordered to match the submitted requests.
type EmbeddingResponse struct {
	Embeddings []EmbeddingData `json:"embeddings"`
}
|
||||
|
||||
// EmbeddingData holds one embedding vector returned by Gemini.
type EmbeddingData struct {
	Values []float64 `json:"values"`
}
|
||||
|
Loading…
Reference in New Issue
Block a user