chore: update implementation

parent 8e82edcd8c
commit 7ea90e708e

@@ -187,7 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
-	ChannelTypeGeminiChat     = 24
+	ChannelTypeGemini         = 24
 )
 
 var ChannelBaseURLs = []string{

@@ -83,7 +83,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":   8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":  0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":        1,
-	"gemini-pro":    1,
+	"gemini-pro":    1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
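
The new comment on "gemini-pro" records the unit conversion behind the ratio: Google prices the model per character while one-api meters tokens. Assuming the usual rough estimate of 4 characters per token (the 4:1 figure is an inference, not stated in the diff), $0.00025 / 1k characters × 4 ≈ $0.001 / 1k tokens.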

@@ -20,7 +20,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
-	case common.ChannelTypeGeminiChat:
+	case common.ChannelTypeGemini:
 		fallthrough
 	case common.ChannelTypeAnthropic:
 		fallthrough

@@ -11,25 +11,36 @@ import (
 )
 
 type GeminiChatRequest struct {
-	Contents         []GeminiChatContents       `json:"contents"`
+	Contents         []GeminiChatContent        `json:"contents"`
 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
 	Tools            []GeminiChatTools          `json:"tools,omitempty"`
 }
 
-type GeminiChatParts struct {
-	Text string `json:"text"`
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
 }
 
-type GeminiChatContents struct {
-	Role  string            `json:"role"`
-	Parts []GeminiChatParts `json:"parts"`
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
 }
 
 type GeminiChatSafetySettings struct {
 	Category  string `json:"category"`
 	Threshold string `json:"threshold"`
 }
 
 type GeminiChatTools struct {
 	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
 }
 
 type GeminiChatGenerationConfig struct {
 	Temperature float64 `json:"temperature,omitempty"`
 	TopP        float64 `json:"topP,omitempty"`
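
The reworked types split content into GeminiPart values that carry either plain text or base64-encoded inline media, which is the shape a multimodal request needs. A minimal, self-contained sketch of how these types serialize (types copied from the diff above; the payload values are illustrative only):

package main

import (
	"encoding/json"
	"fmt"
)

// Types copied from the diff above, trimmed to what the sketch needs.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}

type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

func main() {
	content := GeminiChatContent{
		Role: "user",
		Parts: []GeminiPart{
			{Text: "Describe this image."},
			{InlineData: &GeminiInlineData{MimeType: "image/png", Data: "<base64 bytes>"}},
		},
	}
	out, _ := json.MarshalIndent(content, "", "  ")
	fmt.Println(string(out))
	// Each part carries exactly one of text/inlineData; omitempty keeps
	// the unused field out of the JSON.
}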

@@ -40,43 +51,45 @@ type GeminiChatGenerationConfig struct {
 }
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 	geminiRequest := GeminiChatRequest{
-		Contents: make([]GeminiChatContents, 0, len(textRequest.Messages)),
-		SafetySettings: []GeminiChatSafetySettings{
-			{
-				Category:  "HARM_CATEGORY_HARASSMENT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_HATE_SPEECH",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-		},
+		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
+		//SafetySettings: []GeminiChatSafetySettings{
+		//	{
+		//		Category:  "HARM_CATEGORY_HARASSMENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_HATE_SPEECH",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//},
 		GenerationConfig: GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
-		Tools: []GeminiChatTools{
+	}
+	if textRequest.Functions != nil {
+		geminiRequest.Tools = []GeminiChatTools{
 			{
 				FunctionDeclarations: textRequest.Functions,
 			},
-		},
+		}
 	}
 	shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
-		content := GeminiChatContents{
+		content := GeminiChatContent{
 			Role: message.Role,
-			Parts: []GeminiChatParts{
+			Parts: []GeminiPart{
 				{
 					Text: message.StringContent(),
 				},
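
Guarding the Tools assignment behind textRequest.Functions != nil means a request without functions no longer carries a tools entry at all: with the `tools,omitempty` tag a nil slice is dropped from the JSON, whereas the old literal always emitted a one-element tools array even when FunctionDeclarations was nil. A standalone illustration of that omitempty behavior (not code from the repo):

package main

import (
	"encoding/json"
	"fmt"
)

type request struct {
	Tools []struct{} `json:"tools,omitempty"`
}

func main() {
	var r request // Tools left nil, as when textRequest.Functions == nil
	b, _ := json.Marshal(r)
	fmt.Println(string(b)) // {} (the nil slice is omitted entirely)

	r.Tools = []struct{}{{}} // one tool entry, as when functions are passed
	b, _ = json.Marshal(r)
	fmt.Println(string(b)) // {"tools":[{}]}
}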

@@ -95,9 +108,9 @@ func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatReque
 
 	// If a system message is the last message, we need to add a dummy model message to make gemini happy
 	if shouldAddDummyModelMessage {
-		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContents{
+		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
 			Role: "model",
-			Parts: []GeminiChatParts{
+			Parts: []GeminiPart{
 				{
 					Text: "ok",
 				},
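
The dummy "model" turn is there because Gemini expects the contents list to alternate between "user" and "model" roles; when the converted conversation ends with a system message (which, judging by the comment, gets folded into a user-role turn), appending a stub model reply keeps the sequence valid. A toy illustration with stand-in types (the system-to-user mapping is an assumption; the relevant conversion code sits outside this hunk):

package main

import "fmt"

// Stand-in types; the real ones are GeminiChatContent and GeminiPart above.
type part struct{ Text string }
type content struct {
	Role  string
	Parts []part
}

func main() {
	// Hypothetical conversion of a conversation whose last message is a
	// system prompt, folded into a user-role turn:
	contents := []content{
		{Role: "user", Parts: []part{{Text: "You are a terse assistant."}}},
	}
	// The appended dummy turn keeps user/model roles alternating:
	contents = append(contents, content{Role: "model", Parts: []part{{Text: "ok"}}})
	fmt.Println(contents)
}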

@@ -122,15 +135,6 @@ type GeminiChatCandidate struct {
 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 }
 
-type GeminiChatContent struct {
-	Parts []GeminiChatPart `json:"parts"`
-	Role  string           `json:"role"`
-}
-
-type GeminiChatPart struct {
-	Text string `json:"text"`
-}
-
 type GeminiChatSafetyRating struct {
 	Category    string `json:"category"`
 	Probability string `json:"probability"`

@@ -142,6 +146,9 @@ type GeminiChatPromptFeedback struct {
 
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
 	fullTextResponse := OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
+		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {

@@ -151,7 +158,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 				Role:    "assistant",
 				Content: candidate.Content.Parts[0].Text,
 			},
-			FinishReason: "stop",
+			FinishReason: stopFinishReason,
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}

@@ -160,7 +167,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
 	var choice ChatCompletionsStreamResponseChoice
-	if len(geminiResponse.Candidates) > 0 {
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
 		choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
 	}
 	choice.FinishReason = &stopFinishReason
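
The strengthened guard matters because Gemini can return a candidate whose content has no parts, for example when output is cut off by a safety filter, and unconditionally indexing Parts[0] would panic the handler. A minimal sketch of the pattern with stand-in types:

package main

import "fmt"

// Stand-in candidate shape; the real type is GeminiChatCandidate.
type candidate struct{ Parts []string }

// firstText mirrors the guarded access in the diff: both the candidate
// list and its parts list must be non-empty before indexing element 0.
func firstText(cands []candidate) string {
	if len(cands) > 0 && len(cands[0].Parts) > 0 {
		return cands[0].Parts[0]
	}
	return "" // empty delta instead of an index-out-of-range panic
}

func main() {
	fmt.Println(firstText(nil))             // no candidates: safe
	fmt.Println(firstText([]candidate{{}})) // candidate with no parts: safe
}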

@@ -200,7 +207,9 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 			fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
 			fullTextResponse.Id = responseId
 			fullTextResponse.Created = createdTime
-			responseText = geminiResponse.Candidates[0].Content.Parts[0].Text
+			if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+				responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
+			}
 			jsonResponse, err := json.Marshal(fullTextResponse)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())

@@ -27,7 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
-	APITypeGeminiChat
+	APITypeGemini
 )
 
 var httpClient *http.Client

@@ -119,8 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
-	case common.ChannelTypeGeminiChat:
-		apiType = APITypeGeminiChat
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()

@@ -180,11 +180,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
-	case APITypeGeminiChat:
-		switch textRequest.Model {
-		case "gemini-pro":
-			fullRequestURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
-		}
+	case APITypeGemini:
+		requestBaseURL := "https://generativelanguage.googleapis.com"
+		if baseURL != "" {
+			requestBaseURL = baseURL
+		}
+		version := "v1"
+		if c.GetString("api_version") != "" {
+			version = c.GetString("api_version")
+		}
+		action := "generateContent"
+		// actually gemini does not support stream, it's fake
+		//if textRequest.Stream {
+		//	action = "streamGenerateContent"
+		//}
+		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
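
With the defaults above (no channel base URL, no api_version override) the Sprintf yields a URL like the one printed below, and the key is then appended as a query parameter. A quick check of the format string (the model name is just an example):

package main

import "fmt"

func main() {
	requestBaseURL := "https://generativelanguage.googleapis.com"
	version := "v1"       // or the channel's api_version, e.g. "v1beta"
	model := "gemini-pro" // example model name
	action := "generateContent"
	fullRequestURL := fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, model, action)
	fmt.Println(fullRequestURL)
	// https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent
}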

@@ -285,8 +295,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeGeminiChat:
-		geminiChatRequest := requestOpenAI2GeminiChat(textRequest)
+	case APITypeGemini:
+		geminiChatRequest := requestOpenAI2Gemini(textRequest)
 		jsonStr, err := json.Marshal(geminiChatRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)

@@ -385,7 +395,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		req.Header.Set("Authorization", apiKey)
 	case APITypePaLM:
 		// do not set Authorization header
-	case APITypeGeminiChat:
+	case APITypeGemini:
 		// do not set Authorization header
 	default:
 		req.Header.Set("Authorization", "Bearer "+apiKey)

@@ -547,8 +557,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
-	case APITypeGeminiChat:
-		if textRequest.Stream { // Gemini API does not support stream
+	case APITypeGemini:
+		if textRequest.Stream {
 			err, responseText := geminiChatStreamHandler(c, resp)
 			if err != nil {
 				return err

@@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) {
 			c.Set("api_version", channel.Other)
 		case common.ChannelTypeXunfei:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeGemini:
+			c.Set("api_version", channel.Other)
 		case common.ChannelTypeAIProxyLibrary:
 			c.Set("library_id", channel.Other)
 		case common.ChannelTypeAli:
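
The new Distribute case follows the same convention as Zhipu and Xunfei: the channel's Other field doubles as the API version, stashed in the gin context and read back by the URL-building switch above. A sketch of that round trip (gin's test helper is used purely for illustration; the fallback mirrors the diff):

package main

import (
	"fmt"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

func main() {
	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	// Middleware side (as in Distribute): stash the channel's Other field.
	c.Set("api_version", "v1beta") // e.g. channel.Other holding "v1beta"

	// Relay side (as in relayTextHelper): fall back to "v1" when unset.
	version := "v1"
	if v := c.GetString("api_version"); v != "" {
		version = v
	}
	fmt.Println(version) // v1beta
}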