diff --git a/common/image/image.go b/common/image/image.go index 93da6a06..a602936a 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -15,7 +15,22 @@ import ( _ "golang.org/x/image/webp" ) +func IsImageUrl(url string) (bool, error) { + resp, err := http.Head(url) + if err != nil { + return false, err + } + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { + return false, nil + } + return true, nil +} + func GetImageSizeFromUrl(url string) (width int, height int, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } resp, err := http.Get(url) if err != nil { return @@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { return img.Width, img.Height, nil } +func GetImageFromUrl(url string) (mimeType string, data string, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + buffer := bytes.NewBuffer(nil) + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return + } + mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) + return +} + var ( reg = regexp.MustCompile(`data:image/([^;]+);base64,`) ) diff --git a/common/model-ratio.go b/common/model-ratio.go index d6b51f84..fa2adaa1 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{ "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 9ae40f5c..6a759b63 100644 --- a/controller/model.go +++ b/controller/model.go @@ -432,6 +432,15 @@ func init() { Root: "gemini-pro", Parent: nil, }, + { + Id: "gemini-pro-vision", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro-vision", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 523018de..ec55d4b6 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -7,11 +7,18 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/image" "strings" "github.com/gin-gonic/gin" ) +// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn + +const ( + GeminiVisionMaxImageNum = 16 +) + type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` @@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { }, }, } + openaiContent := message.ParseContent() + var parts []GeminiPart + imageNum := 0 + for _, part := range openaiContent { + if part.Type == ContentTypeText { + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == ContentTypeImageURL { + imageNum += 1 + if imageNum > GeminiVisionMaxImageNum { + continue + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } + } + content.Parts = parts + // there's no assistant role in gemini and API shall vomit if Role is not user or model if content.Role == "assistant" { content.Role = "model" diff --git a/controller/relay-text.go b/controller/relay-text.go index c49a2abe..64338545 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if baseURL != "" { fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey case APITypeGemini: requestBaseURL := "https://generativelanguage.googleapis.com" if baseURL != "" { @@ -197,9 +194,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { action = "streamGenerateContent" } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -396,9 +390,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) case APITypeGemini: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) default: req.Header.Set("Authorization", "Bearer "+apiKey) } diff --git a/controller/relay.go b/controller/relay.go index b7906d08..e45fd3eb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -31,6 +31,22 @@ type ImageContent struct { ImageURL *ImageURL `json:"image_url,omitempty"` } +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) + +type OpenAIMessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) IsStringContent() bool { + _, ok := m.Content.(string) + return ok +} + func (m Message) StringContent() string { content, ok := m.Content.(string) if ok { @@ -44,7 +60,7 @@ func (m Message) StringContent() string { if !ok { continue } - if contentMap["type"] == "text" { + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr } @@ -55,6 +71,47 @@ func (m Message) StringContent() string { return "" } +func (m Message) ParseContent() []OpenAIMessageContent { + var contentList []OpenAIMessageContent + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeImageURL, + ImageURL: &ImageURL{ + Url: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + const ( RelayModeUnknown = iota RelayModeChatCompletions diff --git a/middleware/recover.go b/middleware/recover.go index c3a3d748..8338a514 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "runtime/debug" ) func RelayPanicRecover() gin.HandlerFunc { @@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index b1c7ae62..0d4e114d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -91,7 +91,7 @@ const EditChannel = () => { localModels = ['hunyuan']; break; case 24: - localModels = ['gemini-pro']; + localModels = ['gemini-pro', 'gemini-pro-vision']; break; } setInputs((inputs) => ({ ...inputs, models: localModels }));