package google

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"one-api/common"
	"one-api/common/image"
	"one-api/relay/channel/openai"
	"one-api/relay/constant"
	"strings"

	"github.com/gin-gonic/gin"
)

// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn

const (
	// GeminiVisionMaxImageNum caps how many image parts of a single message
	// are forwarded to Gemini; images beyond this count are dropped.
	GeminiVisionMaxImageNum = 16
)

// ConvertGeminiRequest translates an OpenAI-style chat request into a Gemini
// chat request.
//
// Temperature, TopP and MaxTokens are copied into the generation config, and
// every listed harm category uses the threshold in common.GeminiSafetySetting.
// When textRequest.Functions is non-nil it is forwarded as a single tool's
// function declarations.
//
// Each message becomes one content entry. Text parts are copied as-is; image
// URL parts are resolved through image.GetImageFromUrl and sent inline, up to
// GeminiVisionMaxImageNum per message. The "assistant" role is renamed to
// "model", and a "system" message is sent as a "user" message immediately
// followed by a synthetic model reply.
func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
	geminiRequest := GeminiChatRequest{
		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
		SafetySettings: []GeminiChatSafetySettings{
			{
				Category:  "HARM_CATEGORY_HARASSMENT",
				Threshold: common.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_HATE_SPEECH",
				Threshold: common.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
				Threshold: common.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
				Threshold: common.GeminiSafetySetting,
			},
		},
		GenerationConfig: GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
		},
	}
	if textRequest.Functions != nil {
		geminiRequest.Tools = []GeminiChatTools{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}

	for _, message := range textRequest.Messages {
		var parts []GeminiPart
		imageNum := 0
		for _, part := range message.ParseContent() {
			switch part.Type {
			case openai.ContentTypeText:
				parts = append(parts, GeminiPart{
					Text: part.Text,
				})
			case openai.ContentTypeImageURL:
				imageNum++
				if imageNum > GeminiVisionMaxImageNum {
					continue
				}
				mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
				if err != nil {
					// The function cannot return an error, so the failure is
					// only logged; the part is still forwarded with whatever
					// was returned, exactly as before.
					common.SysError("error fetching image for gemini request: " + err.Error())
				}
				parts = append(parts, GeminiPart{
					InlineData: &GeminiInlineData{
						MimeType: mimeType,
						Data:     data,
					},
				})
			}
		}

		content := GeminiChatContent{
			Role:  message.Role,
			Parts: parts,
		}

		// Gemini has no assistant or system role; per the original author the
		// API rejects any role other than "user" or "model".
		isSystem := false
		switch content.Role {
		case "assistant":
			content.Role = "model"
		case "system":
			content.Role = "user"
			isSystem = true
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)

		// A system prompt is sent as a user turn, so it is followed right away
		// by a dummy model turn. This happens after every system message, not
		// only when it is the last one.
		if isSystem {
			geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
				Role: "model",
				Parts: []GeminiPart{
					{
						Text: "Okay",
					},
				},
			})
		}
	}

	return &geminiRequest
}

// GeminiChatResponse is the JSON body decoded from a Gemini chat response.
type GeminiChatResponse struct {
	Candidates     []GeminiChatCandidate    `json:"candidates"`
	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}

// GetResponseText returns the text of the first part of the first candidate.
// It returns "" for a nil receiver, when there are no candidates, or when the
// first candidate has no parts.
func (g *GeminiChatResponse) GetResponseText() string {
	if g == nil {
		return ""
	}
	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
		return g.Candidates[0].Content.Parts[0].Text
	}
	return ""
}

// GeminiChatCandidate is one entry of the "candidates" array in a response.
type GeminiChatCandidate struct {
	Content       GeminiChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int64                    `json:"index"`
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

// GeminiChatSafetyRating is one category/probability pair of a safety rating.
type GeminiChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

// GeminiChatPromptFeedback carries the safety ratings found under
// "promptFeedback" in a response.
type GeminiChatPromptFeedback struct {
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

// responseGeminiChat2OpenAI converts a Gemini response into an OpenAI-style
// "chat.completion" response with a fresh id and timestamp.
//
// Each candidate becomes one assistant choice whose content is the text of the
// candidate's first part ("" if it has none). The finish reason is always
// constant.StopFinishReason. response must be non-nil.
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := openai.TextResponseChoice{
			Index: i,
			Message: openai.Message{
				Role:    "assistant",
				Content: "",
			},
			FinishReason: constant.StopFinishReason,
		}
		if len(candidate.Content.Parts) > 0 {
			choice.Message.Content = candidate.Content.Parts[0].Text
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

// streamResponseGeminiChat2OpenAI wraps the text of a Gemini response (see
// GetResponseText) in a single-choice "chat.completion.chunk" whose model is
// "gemini" and whose finish reason is constant.StopFinishReason.
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = geminiResponse.GetResponseText()
	choice.FinishReason = &constant.StopFinishReason
	var response openai.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Model = "gemini"
	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	return &response
}

// StreamHandler relays a streaming Gemini response to the client as
// OpenAI-style server-sent events.
//
// The upstream body is read line by line. Only lines that, once trimmed, start
// with `"text": "` are used: the prefix and one trailing quote are stripped
// and the rest is treated as the body of a JSON string. Each decoded fragment
// is emitted as a "chat.completion.chunk" event, and "data: [DONE]" is sent
// when the upstream body ends.
//
// It returns the concatenation of all decoded fragments. If closing the
// upstream body fails, it returns an error and an empty string instead.
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
	var responseText strings.Builder
	dataChan := make(chan string)
	stopChan := make(chan bool)
	// done is closed once streaming to the client has finished, so the reader
	// goroutine can never block forever on a channel nobody is receiving from
	// (e.g. when the client disconnects mid-stream).
	done := make(chan struct{})

	// The default split function (bufio.ScanLines) is sufficient: every token
	// is passed through strings.TrimSpace, so its handling of a trailing '\r'
	// makes no observable difference.
	scanner := bufio.NewScanner(resp.Body)
	go func() {
		for scanner.Scan() {
			data := strings.TrimSpace(scanner.Text())
			if !strings.HasPrefix(data, "\"text\": \"") {
				continue
			}
			data = strings.TrimPrefix(data, "\"text\": \"")
			data = strings.TrimSuffix(data, "\"")
			select {
			case dataChan <- data:
			case <-done:
				return
			}
		}
		if err := scanner.Err(); err != nil {
			select {
			case <-done:
				// The consumer has already left and the body is being closed;
				// a read error at this point is expected.
			default:
				common.SysError("error reading stream response: " + err.Error())
			}
		}
		select {
		case stopChan <- true:
		case <-done:
		}
	}()

	common.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// The fragment is still JSON-escaped. Wrapping it in an object and
			// decoding it lets encoding/json resolve the backslash escapes.
			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
			var chunk struct {
				Content string `json:"content"`
			}
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			responseText.WriteString(chunk.Content)

			var choice openai.ChatCompletionsStreamResponseChoice
			choice.Delta.Content = chunk.Content
			response := openai.ChatCompletionsStreamResponse{
				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
				Object:  "chat.completion.chunk",
				Created: common.GetTimestamp(),
				Model:   "gemini-pro",
				Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
			}
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	close(done)

	if err := resp.Body.Close(); err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText.String()
}

// GeminiHandler relays a non-streaming Gemini response to the client as an
// OpenAI-style chat completion.
//
// The upstream body is read in full, decoded and converted with
// responseGeminiChat2OpenAI, then written to the client as JSON with the
// upstream status code. Completion tokens are counted with
// openai.CountTokenText over the first candidate's text; promptTokens is taken
// from the caller.
//
// It returns an error when the body cannot be read, closed or decoded, when
// the response holds no candidates, or when the converted response cannot be
// marshalled. Otherwise it returns the computed usage.
func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	// Close unconditionally so the body is released even when the read failed.
	closeErr := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr != nil {
		return openai.ErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	var geminiResponse GeminiChatResponse
	if err = json.Unmarshal(responseBody, &geminiResponse); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if len(geminiResponse.Candidates) == 0 {
		return &openai.ErrorWithStatusCode{
			Error: openai.Error{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}

	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fullTextResponse.Model = model
	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
	usage := openai.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage

	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		// Headers are already sent, so the failure can only be recorded; the
		// usage is still returned as before.
		common.SysError("error writing gemini response: " + err.Error())
	}
	return nil, &usage
}