chore: update implementation

JustSong 2023-12-17 12:48:04 +08:00
parent 8e82edcd8c
commit 7ea90e708e
6 changed files with 80 additions and 59 deletions

View File

@@ -187,7 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
-	ChannelTypeGeminiChat     = 24
+	ChannelTypeGemini         = 24
 )

 var ChannelBaseURLs = []string{

View File

@@ -83,7 +83,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":   8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":  0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":        1,
-	"gemini-pro":    1,
+	"gemini-pro":    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens

View File

@@ -20,7 +20,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
-	case common.ChannelTypeGeminiChat:
+	case common.ChannelTypeGemini:
 		fallthrough
 	case common.ChannelTypeAnthropic:
 		fallthrough

View File

@@ -11,25 +11,36 @@ import (
 )

 type GeminiChatRequest struct {
-	Contents         []GeminiChatContents       `json:"contents"`
+	Contents         []GeminiChatContent        `json:"contents"`
 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
 	Tools            []GeminiChatTools          `json:"tools,omitempty"`
 }
-type GeminiChatParts struct {
-	Text string `json:"text"`
+
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
 }
-type GeminiChatContents struct {
-	Role  string            `json:"role"`
-	Parts []GeminiChatParts `json:"parts"`
+
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
 }
+
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
+}
+
 type GeminiChatSafetySettings struct {
 	Category  string `json:"category"`
 	Threshold string `json:"threshold"`
 }
 type GeminiChatTools struct {
 	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
 }
 type GeminiChatGenerationConfig struct {
 	Temperature float64 `json:"temperature,omitempty"`
 	TopP        float64 `json:"topP,omitempty"`
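
For reference, a minimal request built from these types marshals to the JSON below. This is a sketch: the message text and temperature are invented, it assumes the GenerationConfig fields not shown in this hunk also carry omitempty tags, and it runs inside this package with encoding/json and fmt imported. Note the mixed tag conventions: the request-level fields keep snake_case (safety_settings, generation_config) while the new part-level fields use camelCase (inlineData, mimeType).

	req := GeminiChatRequest{
		Contents: []GeminiChatContent{
			{
				Role:  "user",
				Parts: []GeminiPart{{Text: "Hello"}},
			},
		},
		GenerationConfig: GeminiChatGenerationConfig{Temperature: 0.7},
	}
	data, _ := json.Marshal(req) // error ignored for brevity
	fmt.Println(string(data))
	// {"contents":[{"role":"user","parts":[{"text":"Hello"}]}],"generation_config":{"temperature":0.7}}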
@@ -40,43 +51,45 @@ type GeminiChatGenerationConfig struct {
 }

 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 	geminiRequest := GeminiChatRequest{
-		Contents: make([]GeminiChatContents, 0, len(textRequest.Messages)),
-		SafetySettings: []GeminiChatSafetySettings{
-			{
-				Category:  "HARM_CATEGORY_HARASSMENT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_HATE_SPEECH",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-			{
-				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
-				Threshold: "BLOCK_ONLY_HIGH",
-			},
-		},
+		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
+		//SafetySettings: []GeminiChatSafetySettings{
+		//	{
+		//		Category:  "HARM_CATEGORY_HARASSMENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_HATE_SPEECH",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//},
 		GenerationConfig: GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
-		Tools: []GeminiChatTools{
-			{
-				FunctionDeclarations: textRequest.Functions,
-			},
-		},
+	}
+	if textRequest.Functions != nil {
+		geminiRequest.Tools = []GeminiChatTools{
+			{
+				FunctionDeclarations: textRequest.Functions,
+			},
+		}
 	}
 	shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
-		content := GeminiChatContents{
+		content := GeminiChatContent{
 			Role: message.Role,
-			Parts: []GeminiChatParts{
+			Parts: []GeminiPart{
 				{
 					Text: message.StringContent(),
 				},
@@ -95,9 +108,9 @@ func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 	// If a system message is the last message, we need to add a dummy model message to make gemini happy
 	if shouldAddDummyModelMessage {
-		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContents{
+		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
 			Role: "model",
-			Parts: []GeminiChatParts{
+			Parts: []GeminiPart{
 				{
 					Text: "ok",
 				},
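
This hunk is the edge case named in the comment: Gemini has no system role, so when a prompt ends with a system message the converter appends a stub "model"/"ok" turn to keep the user/model alternation valid. For a system-only prompt, the resulting contents would look roughly like this (our sketch; the system-to-user role rewrite happens in lines elided from this diff):

	[{"role": "user", "parts": [{"text": "<system prompt>"}]},
	 {"role": "model", "parts": [{"text": "ok"}]}]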
@@ -122,15 +135,6 @@ type GeminiChatCandidate struct {
 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 }
-
-type GeminiChatContent struct {
-	Parts []GeminiChatPart `json:"parts"`
-	Role  string           `json:"role"`
-}
-
-type GeminiChatPart struct {
-	Text string `json:"text"`
-}
 type GeminiChatSafetyRating struct {
 	Category    string `json:"category"`
 	Probability string `json:"probability"`
@@ -142,6 +146,9 @@ type GeminiChatPromptFeedback struct {

 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
 	fullTextResponse := OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
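
With Id, Object, and Created populated, the converted response now carries a full OpenAI-style envelope rather than bare choices, roughly (id and timestamp illustrative):

	{"id": "chatcmpl-<uuid>", "object": "chat.completion", "created": 1702790884, "choices": [...]}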
@@ -151,7 +158,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
 				Role:    "assistant",
 				Content: candidate.Content.Parts[0].Text,
 			},
-			FinishReason: "stop",
+			FinishReason: stopFinishReason,
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -160,7 +167,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {

 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
 	var choice ChatCompletionsStreamResponseChoice
-	if len(geminiResponse.Candidates) > 0 {
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
 		choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
 	}
 	choice.FinishReason = &stopFinishReason
@@ -200,7 +207,9 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 			fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
 			fullTextResponse.Id = responseId
 			fullTextResponse.Created = createdTime
-			responseText = geminiResponse.Candidates[0].Content.Parts[0].Text
+			if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+				responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
+			}
 			jsonResponse, err := json.Marshal(fullTextResponse)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
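
Two fixes land here: the same candidates/parts bounds check that was added to streamResponseGeminiChat2OpenAI above, and = becoming +=, so the responseText returned by the handler accumulates the whole stream instead of retaining only the final chunk.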

View File

@@ -27,7 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
-	APITypeGeminiChat
+	APITypeGemini
 )

 var httpClient *http.Client
@@ -119,8 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
-	case common.ChannelTypeGeminiChat:
-		apiType = APITypeGeminiChat
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -180,11 +180,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
-	case APITypeGeminiChat:
-		switch textRequest.Model {
-		case "gemini-pro":
-			fullRequestURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
-		}
+	case APITypeGemini:
+		requestBaseURL := "https://generativelanguage.googleapis.com"
+		if baseURL != "" {
+			requestBaseURL = baseURL
+		}
+		version := "v1"
+		if c.GetString("api_version") != "" {
+			version = c.GetString("api_version")
+		}
+		action := "generateContent"
+		// actually gemini does not support stream, it's fake
+		//if textRequest.Stream {
+		//	action = "streamGenerateContent"
+		//}
+		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
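
With the defaults above, a request for gemini-pro now resolves to

	https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent?key=<api key>

and, unlike the removed hard-coded v1beta URL, the host (channel base URL), the version segment (the api_version context key, set in the middleware change below), and the model name are all taken from the request and channel configuration.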
@@ -285,8 +295,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeGeminiChat:
-		geminiChatRequest := requestOpenAI2GeminiChat(textRequest)
+	case APITypeGemini:
+		geminiChatRequest := requestOpenAI2Gemini(textRequest)
 		jsonStr, err := json.Marshal(geminiChatRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -385,7 +395,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		req.Header.Set("Authorization", apiKey)
 	case APITypePaLM:
 		// do not set Authorization header
-	case APITypeGeminiChat:
+	case APITypeGemini:
 		// do not set Authorization header
 	default:
 		req.Header.Set("Authorization", "Bearer "+apiKey)
@@ -547,8 +557,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
-	case APITypeGeminiChat:
-		if textRequest.Stream { // Gemini API does not support stream
+	case APITypeGemini:
+		if textRequest.Stream {
 			err, responseText := geminiChatStreamHandler(c, resp)
 			if err != nil {
 				return err

View File

@@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) {
 			c.Set("api_version", channel.Other)
 		case common.ChannelTypeXunfei:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeGemini:
+			c.Set("api_version", channel.Other)
 		case common.ChannelTypeAIProxyLibrary:
 			c.Set("library_id", channel.Other)
 		case common.ChannelTypeAli:
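
This is the other half of the version handling in the relay change above: a Gemini channel's Other field is published as the api_version context key, so a channel can pin, say, v1beta instead of the relay's default v1.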