chore: update implementation

JustSong 2023-12-17 12:48:04 +08:00
parent 8e82edcd8c
commit 7ea90e708e
6 changed files with 80 additions and 59 deletions

View File

@@ -187,7 +187,7 @@ const (
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
-ChannelTypeGeminiChat = 24
+ChannelTypeGemini = 24
)
var ChannelBaseURLs = []string{

View File

@@ -83,7 +83,7 @@ var ModelRatio = map[string]float64{
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens

View File

@@ -20,7 +20,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
-case common.ChannelTypeGeminiChat:
+case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough

View File

@@ -11,25 +11,36 @@ import (
)
type GeminiChatRequest struct {
-Contents []GeminiChatContents `json:"contents"`
+Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
-type GeminiChatParts struct {
-Text string `json:"text"`
+type GeminiInlineData struct {
+MimeType string `json:"mimeType"`
+Data string `json:"data"`
}
-type GeminiChatContents struct {
-Role string `json:"role"`
-Parts []GeminiChatParts `json:"parts"`
+type GeminiPart struct {
+Text string `json:"text,omitempty"`
+InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+type GeminiChatContent struct {
+Role string `json:"role,omitempty"`
+Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
@@ -40,43 +51,45 @@ type GeminiChatGenerationConfig struct {
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
-Contents: make([]GeminiChatContents, 0, len(textRequest.Messages)),
-SafetySettings: []GeminiChatSafetySettings{
-{
-Category: "HARM_CATEGORY_HARASSMENT",
-Threshold: "BLOCK_ONLY_HIGH",
-},
-{
-Category: "HARM_CATEGORY_HATE_SPEECH",
-Threshold: "BLOCK_ONLY_HIGH",
-},
-{
-Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-Threshold: "BLOCK_ONLY_HIGH",
-},
-{
-Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
-Threshold: "BLOCK_ONLY_HIGH",
-},
-},
+Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
+//SafetySettings: []GeminiChatSafetySettings{
+// {
+// Category: "HARM_CATEGORY_HARASSMENT",
+// Threshold: "BLOCK_ONLY_HIGH",
+// },
+// {
+// Category: "HARM_CATEGORY_HATE_SPEECH",
+// Threshold: "BLOCK_ONLY_HIGH",
+// },
+// {
+// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+// Threshold: "BLOCK_ONLY_HIGH",
+// },
+// {
+// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
+// Threshold: "BLOCK_ONLY_HIGH",
+// },
+//},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
-Tools: []GeminiChatTools{
+}
+if textRequest.Functions != nil {
+geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
-},
+}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
-content := GeminiChatContents{
+content := GeminiChatContent{
Role: message.Role,
-Parts: []GeminiChatParts{
+Parts: []GeminiPart{
{
Text: message.StringContent(),
},
@@ -95,9 +108,9 @@ func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatReque
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
-geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContents{
+geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
-Parts: []GeminiChatParts{
+Parts: []GeminiPart{
{
Text: "ok",
},
@@ -122,15 +135,6 @@ type GeminiChatCandidate struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
-type GeminiChatContent struct {
-Parts []GeminiChatPart `json:"parts"`
-Role string `json:"role"`
-}
-type GeminiChatPart struct {
-Text string `json:"text"`
-}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
@@ -142,6 +146,9 @@ type GeminiChatPromptFeedback struct {
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+Object: "chat.completion",
+Created: common.GetTimestamp(),
+Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
@@ -151,7 +158,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
Role: "assistant",
Content: candidate.Content.Parts[0].Text,
},
FinishReason: "stop",
FinishReason: stopFinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
@@ -160,7 +167,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
-if len(geminiResponse.Candidates) > 0 {
+if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
}
choice.FinishReason = &stopFinishReason
@@ -200,7 +207,9 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
-responseText = geminiResponse.Candidates[0].Content.Parts[0].Text
+if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
+}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
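Note (not part of the commit): a minimal sketch of what the reworked converter produces with the new types, assuming the structs exactly as declared above and only omitempty fields inside GeminiChatGenerationConfig:

// Converting a single user message and marshalling it (uses encoding/json):
req := GeminiChatRequest{
	Contents: []GeminiChatContent{
		{Role: "user", Parts: []GeminiPart{{Text: "Hello"}}},
	},
}
b, _ := json.Marshal(req)
// b is roughly:
// {"contents":[{"role":"user","parts":[{"text":"Hello"}]}],"generation_config":{}}
// "generation_config" still appears because json's omitempty never skips a non-pointer struct.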

View File

@@ -27,7 +27,7 @@ const (
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
-APITypeGeminiChat
+APITypeGemini
)
var httpClient *http.Client
@@ -119,8 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
-case common.ChannelTypeGeminiChat:
-apiType = APITypeGeminiChat
+case common.ChannelTypeGemini:
+apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -180,11 +180,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
-case APITypeGeminiChat:
-switch textRequest.Model {
-case "gemini-pro":
-fullRequestURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
-}
+case APITypeGemini:
+requestBaseURL := "https://generativelanguage.googleapis.com"
+if baseURL != "" {
+requestBaseURL = baseURL
+}
+version := "v1"
+if c.GetString("api_version") != "" {
+version = c.GetString("api_version")
+}
+action := "generateContent"
+// actually gemini does not support stream, it's fake
+//if textRequest.Stream {
+// action = "streamGenerateContent"
+//}
+fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
@@ -285,8 +295,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
-case APITypeGeminiChat:
-geminiChatRequest := requestOpenAI2GeminiChat(textRequest)
+case APITypeGemini:
+geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -385,7 +395,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
-case APITypeGeminiChat:
+case APITypeGemini:
// do not set Authorization header
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
@@ -547,8 +557,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
-case APITypeGeminiChat:
-if textRequest.Stream { // Gemini API does not support stream
+case APITypeGemini:
+if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
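Note: taken together with the commented-out streamGenerateContent action above, Gemini streaming stays emulated in this commit: the relay always calls generateContent once, and geminiChatStreamHandler repackages the single response as OpenAI-style stream chunks, which is why both it and streamResponseGeminiChat2OpenAI now guard against empty Candidates and Parts before reading Parts[0].Text.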

View File

@@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) {
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
+case common.ChannelTypeGemini:
+c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeAli:
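Note (values assumed for illustration): this mirrors the Zhipu/Xunfei cases, so a Gemini channel's Other field now carries the API version end to end:

// channel.Other = "v1beta"            (configured on the channel)
// middleware:  c.Set("api_version", channel.Other)
// relay-text:  version := c.GetString("api_version") // "v1beta"
// request URL: https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent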