fest: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini

This commit is contained in:
David Zhuang 2023-12-16 02:50:46 -05:00
parent 6395c7377e
commit d027041f67
3 changed files with 46 additions and 18 deletions

View File

@ -215,4 +215,5 @@ var ChannelBaseURLs = []string{
"https://api.aiproxy.io", // 21 "https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22 "https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23 "https://hunyuan.cloud.tencent.com", //23
"", //24
} }

View File

@ -12,25 +12,31 @@ import (
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContents `json:"contents"` Contents []GeminiChatContents `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config"` GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
} }
type GeminiChatParts struct { type GeminiChatParts struct {
Text string `json:"text"` Text string `json:"text"`
} }
type GeminiChatContents struct { type GeminiChatContents struct {
Role string `json:"role"` Role string `json:"role"`
Parts GeminiChatParts `json:"parts"` Parts []GeminiChatParts `json:"parts"`
} }
type GeminiChatSafetySettings struct { type GeminiChatSafetySettings struct {
Category string `json:"category"` Category string `json:"category"`
Threshold string `json:"threshold"` Threshold string `json:"threshold"`
} }
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct { type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP"` TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK"` TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens"` MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
} }
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
@ -58,19 +64,44 @@ func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatReque
GenerationConfig: GeminiChatGenerationConfig{ GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.MaxTokens,
}, },
Tools: []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
},
} }
systemPrompt := ""
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
content := GeminiChatContents{ content := GeminiChatContents{
Role: message.Role, Role: message.Role,
Parts: GeminiChatParts{ Parts: []GeminiChatParts{
Text: message.StringContent(), {
Text: message.StringContent(),
},
}, },
} }
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
systemPrompt = message.StringContent()
continue
}
if content.Role == "user" && systemPrompt != "" {
content.Parts = []GeminiChatParts{
{
Text: systemPrompt + "\n\nHuman: " + message.StringContent(),
},
}
systemPrompt = ""
}
geminiRequest.Contents = append(geminiRequest.Contents, content) geminiRequest.Contents = append(geminiRequest.Contents, content)
} }
return &geminiRequest return &geminiRequest
} }

View File

@ -119,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent: case common.ChannelTypeTencent:
apiType = APITypeTencent apiType = APITypeTencent
case common.ChannelTypeGeminiChat:
apiType = APITypeGeminiChat
} }
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
@ -179,15 +181,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey fullRequestURL += "?key=" + apiKey
case APITypeGeminiChat: case APITypeGeminiChat:
requestURLSuffix := "/v1beta/models/gemini-pro:generateContent"
switch textRequest.Model { switch textRequest.Model {
case "gemini-pro": case "gemini-pro":
requestURLSuffix = "/v1beta/models/gemini-pro:generateContent" fullRequestURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
}
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURLSuffix)
} else {
fullRequestURL = fmt.Sprintf("https://generativelanguage.googleapis.com%s", requestURLSuffix)
} }
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")