feat: add support for gemini embedding-001
This commit is contained in:
parent
6b27d6659a
commit
7038d2a71b
@ -127,6 +127,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"moonshot-v1-8k": 0.012 * RMB,
|
"moonshot-v1-8k": 0.012 * RMB,
|
||||||
"moonshot-v1-32k": 0.024 * RMB,
|
"moonshot-v1-32k": 0.024 * RMB,
|
||||||
"moonshot-v1-128k": 0.06 * RMB,
|
"moonshot-v1-128k": 0.06 * RMB,
|
||||||
|
"embedding-001": 0.01 * RMB,
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
channelhelper "github.com/songquanpeng/one-api/relay/channel"
|
channelhelper "github.com/songquanpeng/one-api/relay/channel"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||||
|
relaymode "github.com/songquanpeng/one-api/relay/constant"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"github.com/songquanpeng/one-api/relay/util"
|
"github.com/songquanpeng/one-api/relay/util"
|
||||||
"io"
|
"io"
|
||||||
@ -17,15 +18,19 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||||
|
fmt.Println(meta.APIVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
||||||
version := helper.AssignOrDefault(meta.APIVersion, "v1")
|
version := helper.AssignOrDefault(meta.APIVersion, "v1")
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if meta.IsStream {
|
|
||||||
|
if relaymode.RelayModeEmbeddings == meta.Mode {
|
||||||
|
action = "batchEmbedContents"
|
||||||
|
} else if meta.IsStream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,7 +44,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if relaymode.RelayModeEmbeddings == relayMode {
|
||||||
|
return ConvertEmbeddingRequest(*request), nil
|
||||||
|
} else {
|
||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
|
||||||
@ -47,7 +57,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
if meta.IsStream {
|
if relaymode.RelayModeEmbeddings == meta.Mode {
|
||||||
|
err, usage = EmbeddingHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||||
|
} else if meta.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = StreamHandler(c, resp)
|
err, responseText = StreamHandler(c, resp)
|
||||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||||
|
@ -3,4 +3,5 @@ package gemini
|
|||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"gemini-pro",
|
"gemini-pro",
|
||||||
"gemini-pro-vision",
|
"gemini-pro-vision",
|
||||||
|
"(Gemini)embedding-001",
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ const (
|
|||||||
VisionMaxImageNum = 16
|
VisionMaxImageNum = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// ConvertRequest Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||||
geminiRequest := ChatRequest{
|
geminiRequest := ChatRequest{
|
||||||
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
||||||
@ -122,6 +122,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
return &geminiRequest
|
return &geminiRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConvertEmbeddingRequest converts a GeneralOpenAIRequest to an EmbeddingMultiRequest
|
||||||
|
func ConvertEmbeddingRequest(textRequest model.GeneralOpenAIRequest) *EmbeddingMultiRequest {
|
||||||
|
inputs := textRequest.ParseInput()
|
||||||
|
requests := make([]EmbeddingRequest, 0, len(inputs))
|
||||||
|
for _, input := range inputs {
|
||||||
|
requests = append(requests, EmbeddingRequest{
|
||||||
|
Model: "models/embedding-001",
|
||||||
|
Content: ChatContent{
|
||||||
|
Parts: []Part{
|
||||||
|
{
|
||||||
|
Text: input,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &EmbeddingMultiRequest{
|
||||||
|
Requests: requests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
Candidates []ChatCandidate `json:"candidates"`
|
Candidates []ChatCandidate `json:"candidates"`
|
||||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||||
@ -258,6 +279,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
return nil, responseText
|
return nil, responseText
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingHandler is a function that handles embedding requests
|
||||||
|
func EmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var geminiError Error
|
||||||
|
err = json.Unmarshal(body, &geminiError)
|
||||||
|
if geminiError.Code != 0 || err != nil {
|
||||||
|
return &model.ErrorWithStatusCode{
|
||||||
|
Error: model.Error{
|
||||||
|
Message: geminiError.Message,
|
||||||
|
Type: geminiError.Details[0].Type,
|
||||||
|
Param: geminiError.Status,
|
||||||
|
Code: geminiError.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
var geminiResponse EmbeddingResponse
|
||||||
|
err = json.Unmarshal(body, &geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
fullTextResponse := embeddingResponseGemini2OpenAI(&geminiResponse, promptTokens, modelName)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -301,3 +361,21 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
|||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func embeddingResponseGemini2OpenAI(geminiResponse *EmbeddingResponse, promptTokens int, modelName string) *openai.EmbeddingResponse {
|
||||||
|
data := make([]openai.EmbeddingResponseItem, 0, len(geminiResponse.Embeddings))
|
||||||
|
|
||||||
|
for index, embedding := range geminiResponse.Embeddings {
|
||||||
|
data = append(data, openai.EmbeddingResponseItem{
|
||||||
|
Object: "embedding",
|
||||||
|
Embedding: embedding.Values,
|
||||||
|
Index: index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &openai.EmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
Model: modelName,
|
||||||
|
Usage: model.Usage{TotalTokens: promptTokens},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -39,3 +39,32 @@ type ChatGenerationConfig struct {
|
|||||||
CandidateCount int `json:"candidateCount,omitempty"`
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
StopSequences []string `json:"stopSequences,omitempty"`
|
StopSequences []string `json:"stopSequences,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Details []struct {
|
||||||
|
Type string `json:"@type"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
} `json:"details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content ChatContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingMultiRequest struct {
|
||||||
|
Requests []EmbeddingRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Embeddings []EmbeddingData `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingData struct {
|
||||||
|
Values []float64 `json:"values"`
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user